From 168baa47169eeb499c35f16fb0fdbc081929c8f9 Mon Sep 17 00:00:00 2001 From: Henry Avetisyan Date: Tue, 30 May 2023 15:25:34 -0700 Subject: [PATCH] support athenz as oidc provider for aws iam (#2190) Signed-off-by: Henry Avetisyan Co-authored-by: Henry Avetisyan --- clients/go/zts/client.go | 4 +- clients/go/zts/zts_schema.go | 1 + .../java/com/yahoo/athenz/zts/ZTSClient.java | 2 +- .../athenz/zts/ZTSRDLGeneratedClient.java | 5 +- .../com/yahoo/athenz/zts/ZTSClientTest.java | 2 +- .../yahoo/athenz/zts/ZTSRDLClientMock.java | 3 +- containers/jetty/conf/athenz.properties | 10 + .../yahoo/athenz/container/AthenzConsts.java | 1 + .../container/AthenzJettyContainer.java | 60 ++-- .../container/AthenzJettyContainerTest.java | 33 +- .../java/com/yahoo/athenz/zts/ZTSSchema.java | 1 + core/zts/src/main/rdl/OAuth.rdli | 3 +- libs/go/athenzutils/idtoken.go | 4 +- servers/zts/conf/zts.properties | 12 +- .../java/com/yahoo/athenz/zts/ZTSConsts.java | 2 + .../java/com/yahoo/athenz/zts/ZTSHandler.java | 2 +- .../java/com/yahoo/athenz/zts/ZTSImpl.java | 77 +++-- .../com/yahoo/athenz/zts/ZTSResources.java | 5 +- .../com/yahoo/athenz/zts/ZTSImplTest.java | 291 ++++++++++++++---- utils/zts-idtoken/zts-idtoken.go | 11 +- 20 files changed, 404 insertions(+), 125 deletions(-) diff --git a/clients/go/zts/client.go b/clients/go/zts/client.go index 9694016f011..5badd088034 100644 --- a/clients/go/zts/client.go +++ b/clients/go/zts/client.go @@ -1156,9 +1156,9 @@ func (client ZTSClient) PostAccessTokenRequest(request AccessTokenRequest) (*Acc } } -func (client ZTSClient) GetOIDCResponse(responseType string, clientId ServiceName, redirectUri string, scope string, state EntityName, nonce EntityName, keyType SimpleName, fullArn *bool, expiryTime *int32, output SimpleName) (*OIDCResponse, string, error) { +func (client ZTSClient) GetOIDCResponse(responseType string, clientId ServiceName, redirectUri string, scope string, state EntityName, nonce EntityName, keyType SimpleName, fullArn *bool, expiryTime *int32, output SimpleName, roleInAudClaim *bool) (*OIDCResponse, string, error) { var data *OIDCResponse - url := client.URL + "/oauth2/auth" + encodeParams(encodeStringParam("response_type", string(responseType), ""), encodeStringParam("client_id", string(clientId), ""), encodeStringParam("redirect_uri", string(redirectUri), ""), encodeStringParam("scope", string(scope), ""), encodeStringParam("state", string(state), ""), encodeStringParam("nonce", string(nonce), ""), encodeStringParam("keyType", string(keyType), ""), encodeOptionalBoolParam("fullArn", fullArn), encodeOptionalInt32Param("expiryTime", expiryTime), encodeStringParam("output", string(output), "")) + url := client.URL + "/oauth2/auth" + encodeParams(encodeStringParam("response_type", string(responseType), ""), encodeStringParam("client_id", string(clientId), ""), encodeStringParam("redirect_uri", string(redirectUri), ""), encodeStringParam("scope", string(scope), ""), encodeStringParam("state", string(state), ""), encodeStringParam("nonce", string(nonce), ""), encodeStringParam("keyType", string(keyType), ""), encodeOptionalBoolParam("fullArn", fullArn), encodeOptionalInt32Param("expiryTime", expiryTime), encodeStringParam("output", string(output), ""), encodeOptionalBoolParam("roleInAudClaim", roleInAudClaim)) resp, err := client.httpGet(url, nil) if err != nil { return nil, "", err diff --git a/clients/go/zts/zts_schema.go b/clients/go/zts/zts_schema.go index 2fbbd02d249..39a2eaa1c3c 100644 --- a/clients/go/zts/zts_schema.go +++ b/clients/go/zts/zts_schema.go @@ -1043,6 +1043,7 @@ func init() { mGetOIDCResponse.Input("fullArn", "Bool", false, "fullArn", "", true, false, "flag to indicate to use full arn in group claim (e.g. sports:role.deployer instead of deployer)") mGetOIDCResponse.Input("expiryTime", "Int32", false, "expiryTime", "", true, nil, "optional expiry period specified in seconds") mGetOIDCResponse.Input("output", "SimpleName", false, "output", "", true, nil, "optional output format of json") + mGetOIDCResponse.Input("roleInAudClaim", "Bool", false, "roleInAudClaim", "", true, false, "flag to indicate to include role name in the audience claim only if we have a single role in response") mGetOIDCResponse.Output("location", "String", "Location", false, "return location header with id token") mGetOIDCResponse.Auth("", "", true, "") mGetOIDCResponse.Exception("BAD_REQUEST", "ResourceError", "") diff --git a/clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSClient.java b/clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSClient.java index 09ad5781ebe..b6c9e6b6e1c 100644 --- a/clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSClient.java +++ b/clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSClient.java @@ -3168,7 +3168,7 @@ public OIDCResponse getIDToken(String responseType, String clientId, String redi try { Map> responseHeaders = new HashMap<>(); oidcResponse = ztsClient.getOIDCResponse(responseType, clientId, redirectUri, scope, - state, Crypto.randomSalt(), keyType, fullArn, expiryTime, "json", responseHeaders); + state, Crypto.randomSalt(), keyType, fullArn, expiryTime, "json", false, responseHeaders); } catch (ResourceException ex) { diff --git a/clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSRDLGeneratedClient.java b/clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSRDLGeneratedClient.java index e92142f5f6b..3f6f4f8fbb0 100644 --- a/clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSRDLGeneratedClient.java +++ b/clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSRDLGeneratedClient.java @@ -930,7 +930,7 @@ public AccessTokenResponse postAccessTokenRequest(String request) throws URISynt } } - public OIDCResponse getOIDCResponse(String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, java.util.Map> headers) throws URISyntaxException, IOException { + public OIDCResponse getOIDCResponse(String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, Boolean roleInAudClaim, java.util.Map> headers) throws URISyntaxException, IOException { UriTemplateBuilder uriTemplateBuilder = new UriTemplateBuilder(baseUrl, "/oauth2/auth"); URIBuilder uriBuilder = new URIBuilder(uriTemplateBuilder.getUri()); if (responseType != null) { @@ -963,6 +963,9 @@ public OIDCResponse getOIDCResponse(String responseType, String clientId, String if (output != null) { uriBuilder.setParameter("output", output); } + if (roleInAudClaim != null) { + uriBuilder.setParameter("roleInAudClaim", String.valueOf(roleInAudClaim)); + } HttpUriRequest httpUriRequest = RequestBuilder.get() .setUri(uriBuilder.build()) .build(); diff --git a/clients/java/zts/src/test/java/com/yahoo/athenz/zts/ZTSClientTest.java b/clients/java/zts/src/test/java/com/yahoo/athenz/zts/ZTSClientTest.java index f49062a3df3..8ddc6d5d4b0 100644 --- a/clients/java/zts/src/test/java/com/yahoo/athenz/zts/ZTSClientTest.java +++ b/clients/java/zts/src/test/java/com/yahoo/athenz/zts/ZTSClientTest.java @@ -2208,7 +2208,7 @@ public void testGetAWSTemporaryCredentialsException() { @Test public void testHostnameVerifierSupport() { - ZTSRDLGeneratedClientMock client = new ZTSRDLGeneratedClientMock("http://localhost:4080", (HostnameVerifier) null); + ZTSRDLGeneratedClientMock client = new ZTSRDLGeneratedClientMock("http://localhost:4080", null); HostnameVerifier hostnameVerifier = client.getHostnameVerifier(); assertTrue(hostnameVerifier == null || hostnameVerifier instanceof org.apache.http.conn.ssl.DefaultHostnameVerifier); diff --git a/clients/java/zts/src/test/java/com/yahoo/athenz/zts/ZTSRDLClientMock.java b/clients/java/zts/src/test/java/com/yahoo/athenz/zts/ZTSRDLClientMock.java index f7572b72845..ce974f24bea 100644 --- a/clients/java/zts/src/test/java/com/yahoo/athenz/zts/ZTSRDLClientMock.java +++ b/clients/java/zts/src/test/java/com/yahoo/athenz/zts/ZTSRDLClientMock.java @@ -241,7 +241,8 @@ public AccessTokenResponse postAccessTokenRequest(String request) { @Override public OIDCResponse getOIDCResponse(String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, - String output, Map> headers) throws URISyntaxException, IOException { + String output, Boolean roleInAudClaim, Map> headers) + throws URISyntaxException, IOException { // some exception test cases based on the state value if (state != null) { diff --git a/containers/jetty/conf/athenz.properties b/containers/jetty/conf/athenz.properties index dd098d29222..94a8fe3f757 100644 --- a/containers/jetty/conf/athenz.properties +++ b/containers/jetty/conf/athenz.properties @@ -20,6 +20,16 @@ athenz.port=0 # ports are specified, https will be selected for the protocol. #athenz.status_port= +# Port for handling OIDC requests. If different than the configured +# https port, then the server will create a separate connector +# to handle the oidc requests only. This includes issuing id +# tokens, returning public keys and openid configuration details. +# All other requests on this port will be rejected. This is useful +# when you want to integrate with another component that requires +# the service to run on a specific port - e.g. AWS IAM OIDC provider +# requires it to run on port 443 only. +#athenz.oidc_port= + # Set the number of days before rotated access log files are deleted #athenz.access_log_retain_days=31 diff --git a/containers/jetty/src/main/java/com/yahoo/athenz/container/AthenzConsts.java b/containers/jetty/src/main/java/com/yahoo/athenz/container/AthenzConsts.java index c2ba3ddef33..f282d899111 100644 --- a/containers/jetty/src/main/java/com/yahoo/athenz/container/AthenzConsts.java +++ b/containers/jetty/src/main/java/com/yahoo/athenz/container/AthenzConsts.java @@ -82,6 +82,7 @@ public final class AthenzConsts { public static final String ATHENZ_PROP_HTTP_PORT = "athenz.port"; public static final String ATHENZ_PROP_HTTPS_PORT = "athenz.tls_port"; + public static final String ATHENZ_PROP_OIDC_PORT = "athenz.oidc_port"; public static final String ATHENZ_PROP_STATUS_PORT = "athenz.status_port"; public static final int ATHENZ_HTTPS_PORT_DEFAULT = 4443; diff --git a/containers/jetty/src/main/java/com/yahoo/athenz/container/AthenzJettyContainer.java b/containers/jetty/src/main/java/com/yahoo/athenz/container/AthenzJettyContainer.java index 1cb7c5b83f8..4963d711ab7 100644 --- a/containers/jetty/src/main/java/com/yahoo/athenz/container/AthenzJettyContainer.java +++ b/containers/jetty/src/main/java/com/yahoo/athenz/container/AthenzJettyContainer.java @@ -474,9 +474,17 @@ void addHTTPSConnector(HttpConfiguration httpsConfig, int httpsPort, boolean pro } } } - + + HttpConfiguration getHttpsConfig(HttpConfiguration httpConfig, int httpsPort, boolean sniRequired, boolean sniHostCheck) { + HttpConfiguration httpsConfig = new HttpConfiguration(httpConfig); + httpsConfig.setSecureScheme("https"); + httpsConfig.setSecurePort(httpsPort); + httpsConfig.addCustomizer(new SecureRequestCustomizer(sniRequired, sniHostCheck, -1L, false)); + return httpsConfig; + } + public void addHTTPConnectors(HttpConfiguration httpConfig, int httpPort, int httpsPort, - int statusPort) { + int oidcPort, int statusPort) { int idleTimeout = Integer.parseInt( System.getProperty(AthenzConsts.ATHENZ_PROP_IDLE_TIMEOUT, "30000")); @@ -500,38 +508,34 @@ public void addHTTPConnectors(HttpConfiguration httpConfig, int httpPort, int ht connectionLogger = jettyConnectionLoggerFactory.create(); } + boolean sniRequired = Boolean.parseBoolean( + System.getProperty(AthenzConsts.ATHENZ_PROP_SNI_REQUIRED, "false")); + boolean sniHostCheck = Boolean.parseBoolean( + System.getProperty(AthenzConsts.ATHENZ_PROP_SNI_HOSTCHECK, "true")); + boolean needClientAuth = Boolean.parseBoolean( + System.getProperty(AthenzConsts.ATHENZ_PROP_CLIENT_AUTH, "false")); + // HTTPS Connector if (httpsPort > 0) { + HttpConfiguration httpsConfig = getHttpsConfig(httpConfig, httpsPort, sniRequired, sniHostCheck); + addHTTPSConnector(httpsConfig, httpsPort, proxyProtocol, listenHost, + idleTimeout, needClientAuth, connectionLogger); + } - boolean sniRequired = Boolean.parseBoolean( - System.getProperty(AthenzConsts.ATHENZ_PROP_SNI_REQUIRED, "false")); - boolean sniHostCheck = Boolean.parseBoolean( - System.getProperty(AthenzConsts.ATHENZ_PROP_SNI_HOSTCHECK, "true")); - - HttpConfiguration httpsConfig = new HttpConfiguration(httpConfig); - httpsConfig.setSecureScheme("https"); - httpsConfig.setSecurePort(httpsPort); - httpsConfig.addCustomizer(new SecureRequestCustomizer(sniRequired, sniHostCheck, -1L, false)); - - boolean needClientAuth = Boolean.parseBoolean( - System.getProperty(AthenzConsts.ATHENZ_PROP_CLIENT_AUTH, "false")); + // OIDC Connector - only if it's different from HTTPS - addHTTPSConnector(httpsConfig, httpsPort, proxyProtocol, listenHost, + if (oidcPort > 0 && oidcPort != httpsPort) { + HttpConfiguration httpsConfig = getHttpsConfig(httpConfig, oidcPort, sniRequired, sniHostCheck); + addHTTPSConnector(httpsConfig, oidcPort, proxyProtocol, listenHost, idleTimeout, needClientAuth, connectionLogger); } - + // Status Connector - only if it's different from HTTP/HTTPS if (statusPort > 0 && statusPort != httpPort && statusPort != httpsPort) { - if (httpsPort > 0) { - - HttpConfiguration httpsConfig = new HttpConfiguration(httpConfig); - httpsConfig.setSecureScheme("https"); - httpsConfig.setSecurePort(httpsPort); - httpsConfig.addCustomizer(new SecureRequestCustomizer(false, false, -1L, false)); - + HttpConfiguration httpsConfig = getHttpsConfig(httpConfig, httpsPort, false, false); addHTTPSConnector(httpsConfig, statusPort, false, listenHost, idleTimeout, false, connectionLogger); } else if (httpPort > 0) { addHTTPConnector(httpConfig, statusPort, false, listenHost, idleTimeout); @@ -592,7 +596,11 @@ public static AthenzJettyContainer createJettyContainer() { AthenzConsts.ATHENZ_HTTP_PORT_DEFAULT); int httpsPort = ConfigProperties.getPortNumber(AthenzConsts.ATHENZ_PROP_HTTPS_PORT, AthenzConsts.ATHENZ_HTTPS_PORT_DEFAULT); - + + // extract the port for oidc requests if one is configured + + int oidcPort = ConfigProperties.getPortNumber(AthenzConsts.ATHENZ_PROP_OIDC_PORT, 0); + // for status port we'll use the protocol specified for the regular http // port. if both http and https are provided then https will be picked // it could also be either one of the values specified as well @@ -604,14 +612,14 @@ public static AthenzJettyContainer createJettyContainer() { AthenzJettyContainer container = new AthenzJettyContainer(); container.setBanner("http://" + serverHostName + " http port: " + httpPort + " https port: " + httpsPort + " status port: " + - statusPort); + statusPort + " oidc port: " + oidcPort); int maxThreads = Integer.parseInt(System.getProperty(AthenzConsts.ATHENZ_PROP_MAX_THREADS, Integer.toString(AthenzConsts.ATHENZ_HTTP_MAX_THREADS))); container.createServer(maxThreads); HttpConfiguration httpConfig = container.newHttpConfiguration(); - container.addHTTPConnectors(httpConfig, httpPort, httpsPort, statusPort); + container.addHTTPConnectors(httpConfig, httpPort, httpsPort, oidcPort, statusPort); container.addServletHandlers(serverHostName); container.addRequestLogHandler(); diff --git a/containers/jetty/src/test/java/com/yahoo/athenz/container/AthenzJettyContainerTest.java b/containers/jetty/src/test/java/com/yahoo/athenz/container/AthenzJettyContainerTest.java index 22e45210b10..43f8e807e55 100644 --- a/containers/jetty/src/test/java/com/yahoo/athenz/container/AthenzJettyContainerTest.java +++ b/containers/jetty/src/test/java/com/yahoo/athenz/container/AthenzJettyContainerTest.java @@ -215,17 +215,20 @@ public void testHttpConnectorsBoth() { container.createServer(100); HttpConfiguration httpConfig = container.newHttpConfiguration(); - container.addHTTPConnectors(httpConfig, 8081, 8082, 0); + container.addHTTPConnectors(httpConfig, 8081, 8082, 443, 0); Server server = container.getServer(); Connector[] connectors = server.getConnectors(); - assertEquals(connectors.length, 2); + assertEquals(connectors.length, 3); assertEquals(connectors[0].getIdleTimeout(), 10001); assertTrue(connectors[0].getProtocols().contains("http/1.1")); assertTrue(connectors[1].getProtocols().contains("http/1.1")); assertTrue(connectors[1].getProtocols().contains("ssl")); + + assertTrue(connectors[2].getProtocols().contains("http/1.1")); + assertTrue(connectors[2].getProtocols().contains("ssl")); } @Test @@ -247,7 +250,7 @@ public void testNonExistantKeyStore() { HttpConfiguration httpConfig = container.newHttpConfiguration(); try { // This should throw - container.addHTTPConnectors(httpConfig, 8081, 8082, 0); + container.addHTTPConnectors(httpConfig, 8081, 8082, 0, 0); fail(); } catch (IllegalArgumentException exception) { // as expected @@ -270,7 +273,7 @@ public void testHttpConnectorsHttpsOnly() { container.createServer(100); HttpConfiguration httpConfig = container.newHttpConfiguration(); - container.addHTTPConnectors(httpConfig, 0, 8082, 0); + container.addHTTPConnectors(httpConfig, 0, 8082, 0, 0); Server server = container.getServer(); Connector[] connectors = server.getConnectors(); @@ -296,7 +299,7 @@ public void testHttpConnectorsHttpOnly() { container.createServer(100); HttpConfiguration httpConfig = container.newHttpConfiguration(); - container.addHTTPConnectors(httpConfig, 8081, 0, 0); + container.addHTTPConnectors(httpConfig, 8081, 0, 0, 0); Server server = container.getServer(); Connector[] connectors = server.getConnectors(); @@ -447,7 +450,7 @@ public void testInitContainerValidPorts() { Server server = container.getServer(); Connector[] connectors = server.getConnectors(); - assertEquals(connectors.length, 2); + assertEquals(connectors.length, 3); assertTrue(connectors[0].getProtocols().contains("http/1.1")); @@ -460,6 +463,7 @@ public void testInitContainerOnlyHTTPSPort() { System.setProperty(AthenzConsts.ATHENZ_PROP_HTTP_PORT, "0"); System.setProperty(AthenzConsts.ATHENZ_PROP_HTTPS_PORT, "4443"); + System.setProperty(AthenzConsts.ATHENZ_PROP_OIDC_PORT, "8443"); System.setProperty("yahoo.zms.debug.user_authority", "true"); AthenzJettyContainer container = AthenzJettyContainer.createJettyContainer(); @@ -467,10 +471,13 @@ public void testInitContainerOnlyHTTPSPort() { Server server = container.getServer(); Connector[] connectors = server.getConnectors(); - assertEquals(connectors.length, 1); + assertEquals(connectors.length, 2); assertTrue(connectors[0].getProtocols().contains("http/1.1")); assertTrue(connectors[0].getProtocols().contains("ssl")); + + assertTrue(connectors[1].getProtocols().contains("http/1.1")); + assertTrue(connectors[1].getProtocols().contains("ssl")); } @Test @@ -478,6 +485,7 @@ public void testInitContainerOnlyHTTPPort() { System.setProperty(AthenzConsts.ATHENZ_PROP_HTTP_PORT, "4080"); System.setProperty(AthenzConsts.ATHENZ_PROP_HTTPS_PORT, "0"); + System.setProperty(AthenzConsts.ATHENZ_PROP_OIDC_PORT, "0"); AthenzJettyContainer container = AthenzJettyContainer.createJettyContainer(); assertNotNull(container); @@ -495,18 +503,22 @@ public void testInitContainerInvalidHTTPPort() { System.setProperty(AthenzConsts.ATHENZ_PROP_HTTP_PORT, "-10"); System.setProperty(AthenzConsts.ATHENZ_PROP_HTTPS_PORT, "4443"); - + System.setProperty(AthenzConsts.ATHENZ_PROP_OIDC_PORT, "443"); + AthenzJettyContainer container = AthenzJettyContainer.createJettyContainer(); assertNotNull(container); Server server = container.getServer(); Connector[] connectors = server.getConnectors(); - assertEquals(connectors.length, 2); + assertEquals(connectors.length, 3); assertTrue(connectors[0].getProtocols().contains("http/1.1")); assertTrue(connectors[1].getProtocols().contains("http/1.1")); assertTrue(connectors[1].getProtocols().contains("ssl")); + + assertTrue(connectors[2].getProtocols().contains("http/1.1")); + assertTrue(connectors[2].getProtocols().contains("ssl")); } @Test @@ -514,6 +526,7 @@ public void testInitContainerInvalidHTTPSPort() { System.setProperty(AthenzConsts.ATHENZ_PROP_HTTP_PORT, "4080"); System.setProperty(AthenzConsts.ATHENZ_PROP_HTTPS_PORT, "-10"); + System.setProperty(AthenzConsts.ATHENZ_PROP_OIDC_PORT, "0"); AthenzJettyContainer container = AthenzJettyContainer.createJettyContainer(); assertNotNull(container); @@ -533,6 +546,7 @@ public void testInitContainerOptionalFeatures() { System.setProperty(AthenzConsts.ATHENZ_PROP_HTTP_PORT, "4080"); System.setProperty(AthenzConsts.ATHENZ_PROP_HTTPS_PORT, "4443"); + System.setProperty(AthenzConsts.ATHENZ_PROP_OIDC_PORT, "4443"); System.setProperty(AthenzConsts.ATHENZ_PROP_DEBUG, "true"); System.setProperty(AthenzConsts.ATHENZ_PROP_GZIP_SUPPORT, "true"); System.setProperty(AthenzConsts.ATHENZ_PROP_HEALTH_CHECK_URI_LIST, "/status.html"); @@ -559,6 +573,7 @@ public void testInitContainerStatusPortHTTPS() { System.setProperty(AthenzConsts.ATHENZ_PROP_HTTP_PORT, "4080"); System.setProperty(AthenzConsts.ATHENZ_PROP_HTTPS_PORT, "4443"); + System.setProperty(AthenzConsts.ATHENZ_PROP_OIDC_PORT, "8443"); System.setProperty(AthenzConsts.ATHENZ_PROP_STATUS_PORT, "4444"); AthenzJettyContainer container = AthenzJettyContainer.createJettyContainer(); diff --git a/core/zts/src/main/java/com/yahoo/athenz/zts/ZTSSchema.java b/core/zts/src/main/java/com/yahoo/athenz/zts/ZTSSchema.java index e3e79d866bb..386ac0ac87a 100644 --- a/core/zts/src/main/java/com/yahoo/athenz/zts/ZTSSchema.java +++ b/core/zts/src/main/java/com/yahoo/athenz/zts/ZTSSchema.java @@ -1060,6 +1060,7 @@ private static Schema build() { .queryParam("fullArn", "fullArn", "Bool", false, "flag to indicate to use full arn in group claim (e.g. sports:role.deployer instead of deployer)") .queryParam("expiryTime", "expiryTime", "Int32", null, "optional expiry period specified in seconds") .queryParam("output", "output", "SimpleName", null, "optional output format of json") + .queryParam("roleInAudClaim", "roleInAudClaim", "Bool", false, "flag to indicate to include role name in the audience claim only if we have a single role in response") .output("Location", "location", "String", "return location header with id token") .auth("", "", true) .expected("OK") diff --git a/core/zts/src/main/rdl/OAuth.rdli b/core/zts/src/main/rdl/OAuth.rdli index 6424ee88171..65fb31df69e 100644 --- a/core/zts/src/main/rdl/OAuth.rdli +++ b/core/zts/src/main/rdl/OAuth.rdli @@ -44,7 +44,7 @@ resource AccessTokenResponse POST "/oauth2/token" { } // Fetch OAuth OpenID Connect ID Token -resource OIDCResponse GET "/oauth2/auth?response_type={responseType}&client_id={clientId}&redirect_uri={redirectUri}&scope={scope}&state={state}&nonce={nonce}&keyType={keyType}&fullArn={fullArn}&expiryTime={expiryTime}&output={output}" { +resource OIDCResponse GET "/oauth2/auth?response_type={responseType}&client_id={clientId}&redirect_uri={redirectUri}&scope={scope}&state={state}&nonce={nonce}&keyType={keyType}&fullArn={fullArn}&expiryTime={expiryTime}&output={output}&roleInAudClaim={roleInAudClaim}" { String responseType; //response type - currently only supporting id tokens - id_token ServiceName clientId; //client id - must be valid athenz service identity name String redirectUri; //redirect uri for the response @@ -55,6 +55,7 @@ resource OIDCResponse GET "/oauth2/auth?response_type={responseType}&client_id={ Bool fullArn (optional, default=false); //flag to indicate to use full arn in group claim (e.g. sports:role.deployer instead of deployer) Int32 expiryTime (optional); //optional expiry period specified in seconds SimpleName output (optional); //optional output format of json + Bool roleInAudClaim (optional, default=false); //flag to indicate to include role name in the audience claim only if we have a single role in response String location (header="Location", out); //return location header with id token authenticate; expected OK, FOUND; diff --git a/libs/go/athenzutils/idtoken.go b/libs/go/athenzutils/idtoken.go index 4b41b0fae04..29cefa7ffda 100644 --- a/libs/go/athenzutils/idtoken.go +++ b/libs/go/athenzutils/idtoken.go @@ -14,7 +14,7 @@ import ( "time" ) -func FetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redirectUri, scope, nonce, state, keyType string, fullArn *bool, proxy bool, expireTime *int32) (string, error) { +func FetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redirectUri, scope, nonce, state, keyType string, fullArn *bool, proxy bool, expireTime *int32, roleInAudClaim *bool) (string, error) { client, err := ZtsClient(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, proxy) if err != nil { @@ -23,7 +23,7 @@ func FetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redi client.DisableRedirect = true // request an id token - _, location, err := client.GetOIDCResponse("id_token", zts.ServiceName(clientId), redirectUri, scope, zts.EntityName(state), zts.EntityName(nonce), zts.SimpleName(keyType), fullArn, expireTime, "") + _, location, err := client.GetOIDCResponse("id_token", zts.ServiceName(clientId), redirectUri, scope, zts.EntityName(state), zts.EntityName(nonce), zts.SimpleName(keyType), fullArn, expireTime, "", roleInAudClaim) if err != nil { return "", err } diff --git a/servers/zts/conf/zts.properties b/servers/zts/conf/zts.properties index 12bca1564d2..aac35a592c7 100644 --- a/servers/zts/conf/zts.properties +++ b/servers/zts/conf/zts.properties @@ -593,10 +593,18 @@ athenz.zts.cert_signer_factory_class=com.yahoo.athenz.zts.cert.impl.SelfCertSign # To specify the issuer field in the OpenID configuration metadata object. # This is also used to generate the JWKS URI in the configuration object, -# so it must be the full https scheme endpoint for the server including the port. -# the uri must not contain a trailing /. +# so it must be the full https scheme endpoint for the server including the +# port unless the port is 443. The uri must not contain a trailing /. #athenz.zts.openid_issuer= +# If the administrator has configured a separate port for only OIDC +# requests, then this setting specifies the issuer field for the OpenID +# configuration metadata object. This is also used to generate the JWKS +# URI in the configuration object, so it must be the full https scheme +# endpoint for the server including the port unless the port is 443. +# The uri must not contain a trailing /. +#athenz.zts.oidc_port_issuer= + # The path to the trust store file that contains CA certificates # trusted by the ZTS Provider Client (this client is used to validate # instance register and refresh requests) diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSConsts.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSConsts.java index 917db520c88..0bc1b518302 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSConsts.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSConsts.java @@ -24,6 +24,7 @@ public final class ZTSConsts { public static final String ZTS_PROP_USER_DOMAIN_ALIAS = "athenz.user_domain_alias"; public static final String ZTS_PROP_HTTP_PORT = "athenz.port"; public static final String ZTS_PROP_HTTPS_PORT = "athenz.tls_port"; + public static final String ZTS_PROP_OIDC_PORT = "athenz.oidc_port"; public static final String ZTS_PROP_STATUS_PORT = "athenz.status_port"; public static final String ZTS_PROP_ROOT_DIR = "athenz.zts.root_dir"; @@ -65,6 +66,7 @@ public final class ZTSConsts { public static final String ZTS_PROP_OAUTH_ISSUER = "athenz.zts.oauth_issuer"; public static final String ZTS_PROP_OAUTH_OPENID_SCOPE = "athenz.zts.oauth_openid_scope"; public static final String ZTS_PROP_OPENID_ISSUER = "athenz.zts.openid_issuer"; + public static final String ZTS_PROP_OIDC_PORT_ISSUER = "athenz.zts.oidc_port_issuer"; public static final String ZTS_PROP_REDIRECT_URI_SUFFIX = "athenz.zts.redirect_uri_suffix"; public static final String ZTS_PROP_CERTSIGN_BASE_URI = "athenz.zts.certsign_base_uri"; diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSHandler.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSHandler.java index a47c66889ff..a3191857db6 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSHandler.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSHandler.java @@ -39,7 +39,7 @@ public interface ZTSHandler { OAuthConfig getOAuthConfig(ResourceContext context); JWKList getJWKList(ResourceContext context, Boolean rfc); AccessTokenResponse postAccessTokenRequest(ResourceContext context, String request); - Response getOIDCResponse(ResourceContext context, String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output); + Response getOIDCResponse(ResourceContext context, String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, Boolean roleInAudClaim); RoleCertificate postRoleCertificateRequestExt(ResourceContext context, RoleCertificateRequest req); RoleAccess getRolesRequireRoleCert(ResourceContext context, String principal); Workloads getWorkloadsByService(ResourceContext context, String domainName, String serviceName); diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java index 2801f54eeaa..3cb8d177725 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java @@ -146,6 +146,7 @@ public class ZTSImpl implements KeyStore, ZTSHandler { protected List authFreeUriList = null; protected int httpPort; protected int httpsPort; + protected int oidcPort; protected int statusPort; protected boolean statusCertSigner = false; protected Status successServerStatus = null; @@ -161,8 +162,10 @@ public class ZTSImpl implements KeyStore, ZTSHandler { protected AuthzDetailsEntityList systemAuthzDetails = null; protected ObjectMapper jsonMapper; protected OpenIDConfig openIDConfig; + protected OpenIDConfig oidcPortConfig; protected OAuthConfig oauthConfig; protected String ztsOpenIDIssuer; + protected String ztsOIDCPortIssuer; protected String redirectUriSuffix; protected Info serverInfo = null; protected AthenzJWKConfig jwkConfig; @@ -401,15 +404,21 @@ protected ServiceIdentity sysAuthService(String serviceName) { return serviceIdentity; } + private OpenIDConfig createOpenidIDConfigObject(final String issuer) { + OpenIDConfig config = new OpenIDConfig(); + config.setIssuer(issuer); + config.setJwks_uri(issuer + "/oauth2/keys?rfc=true"); + config.setAuthorization_endpoint(issuer + "/oauth2/auth"); + config.setSubject_types_supported(Collections.singletonList(ZTSConsts.ZTS_OPENID_SUBJECT_TYPE_PUBLIC)); + config.setResponse_types_supported(Collections.singletonList(ZTSConsts.ZTS_OPENID_RESPONSE_IT_ONLY)); + config.setId_token_signing_alg_values_supported(getSupportedSigningAlgValues()); + return config; + } + private void setupMetaConfigObjects() { - openIDConfig = new OpenIDConfig(); - openIDConfig.setIssuer(ztsOpenIDIssuer); - openIDConfig.setJwks_uri(ztsOpenIDIssuer + "/oauth2/keys?rfc=true"); - openIDConfig.setAuthorization_endpoint(ztsOpenIDIssuer + "/oauth2/auth"); - openIDConfig.setSubject_types_supported(Collections.singletonList(ZTSConsts.ZTS_OPENID_SUBJECT_TYPE_PUBLIC)); - openIDConfig.setResponse_types_supported(Collections.singletonList(ZTSConsts.ZTS_OPENID_RESPONSE_IT_ONLY)); - openIDConfig.setId_token_signing_alg_values_supported(getSupportedSigningAlgValues()); + openIDConfig = createOpenidIDConfigObject(ztsOpenIDIssuer); + oidcPortConfig = createOpenidIDConfigObject(ztsOIDCPortIssuer); oauthConfig = new OAuthConfig(); oauthConfig.setIssuer(ztsOpenIDIssuer); @@ -497,6 +506,7 @@ void loadConfigurationSettings() { httpsPort = ConfigProperties.getPortNumber(ZTSConsts.ZTS_PROP_HTTPS_PORT, ZTSConsts.ZTS_HTTPS_PORT_DEFAULT); statusPort = ConfigProperties.getPortNumber(ZTSConsts.ZTS_PROP_STATUS_PORT, 0); + oidcPort = ConfigProperties.getPortNumber(ZTSConsts.ZTS_PROP_OIDC_PORT, 0); successServerStatus = new Status().setCode(ResourceException.OK).setMessage("OK"); @@ -620,6 +630,7 @@ void loadConfigurationSettings() { ztsOAuthIssuer = System.getProperty(ZTSConsts.ZTS_PROP_OAUTH_ISSUER, serverHostName); ztsOpenIDIssuer = System.getProperty(ZTSConsts.ZTS_PROP_OPENID_ISSUER, ztsOAuthIssuer); + ztsOIDCPortIssuer = System.getProperty(ZTSConsts.ZTS_PROP_OIDC_PORT_ISSUER, ztsOpenIDIssuer); redirectUriSuffix = System.getProperty(ZTSConsts.ZTS_PROP_REDIRECT_URI_SUFFIX); // set up our health check file @@ -1932,12 +1943,12 @@ String getQueryLogData(final String request) { @Override public Response getOIDCResponse(ResourceContext ctx, String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, - Integer timeout, String output) { + Integer timeout, String output, Boolean roleInAudClaim) { final String caller = ctx.getApiName(); final String principalDomain = logPrincipalAndGetDomain(ctx); - validateRequest(ctx.request(), principalDomain, caller); + validateOIDCRequest(ctx.request(), principalDomain, caller); validate(nonce, TYPE_ENTITY_NAME, principalDomain, caller); validate(clientId, TYPE_SERVICE_NAME, principalDomain, caller); @@ -2025,9 +2036,9 @@ public Response getOIDCResponse(ResourceContext ctx, String responseType, String IdToken idToken = new IdToken(); idToken.setVersion(1); - idToken.setAudience(clientId); + idToken.setAudience(getIdTokenAudience(clientId, roleInAudClaim, idTokenGroups)); idToken.setSubject(principalName); - idToken.setIssuer(ztsOpenIDIssuer); + idToken.setIssuer(isOidcPortRequest(ctx.request().getLocalPort()) ? ztsOIDCPortIssuer : ztsOpenIDIssuer); idToken.setNonce(nonce); idToken.setGroups(idTokenGroups); idToken.setIssueTime(iat); @@ -2063,6 +2074,11 @@ public Response getOIDCResponse(ResourceContext ctx, String responseType, String } } + String getIdTokenAudience(final String clientId, Boolean includeGroup, List idTokenGroups) { + return (includeGroup == Boolean.TRUE && idTokenGroups != null && idTokenGroups.size() == 1) ? + clientId + ":" + idTokenGroups.get(0) : clientId; + } + List processIdTokenGroups(final String principalName, IdTokenRequest tokenRequest, final String clientIdDomainName, Boolean fullArn, final String principalDomain, final String caller) { @@ -4729,8 +4745,7 @@ public JWKList getJWKList(ResourceContext ctx, Boolean rfc) { final String caller = ctx.getApiName(); final String principalDomain = logPrincipalAndGetDomain(ctx); - validateRequest(ctx.request(), principalDomain, caller); - + validateOIDCRequest(ctx.request(), principalDomain, caller); return dataStore.getZtsJWKList(rfc); } @@ -4906,7 +4921,7 @@ public Status getStatus(ResourceContext ctx) { // validate our request as status request - validateRequest(ctx.request(), principalDomain, caller, true); + validateStatusRequest(ctx.request(), principalDomain, caller); // for now we're going to verify our certsigner connectivity // only if the administrator has configured it. without certsigner @@ -4941,14 +4956,18 @@ public Status getStatus(ResourceContext ctx) { return successServerStatus; } + boolean isOidcPortRequest(int port) { + return port == oidcPort && oidcPort != httpsPort; + } + @Override public OpenIDConfig getOpenIDConfig(ResourceContext ctx) { final String caller = ctx.getApiName(); final String principalDomain = logPrincipalAndGetDomain(ctx); - validateRequest(ctx.request(), principalDomain, caller); - return openIDConfig; + validateOIDCRequest(ctx.request(), principalDomain, caller); + return isOidcPortRequest(ctx.request().getLocalPort()) ? oidcPortConfig : openIDConfig; } @Override @@ -5001,11 +5020,19 @@ public Schema getRdlSchema(ResourceContext context) { } void validateRequest(HttpServletRequest request, final String principalDomain, final String caller) { - validateRequest(request, principalDomain, caller, false); + validateRequest(request, principalDomain, caller, false, false); + } + + void validateOIDCRequest(HttpServletRequest request, final String principalDomain, final String caller) { + validateRequest(request, principalDomain, caller, false, true); + } + + void validateStatusRequest(HttpServletRequest request, final String principalDomain, final String caller) { + validateRequest(request, principalDomain, caller, true, false); } void validateRequest(HttpServletRequest request, final String principalDomain, final String caller, - boolean statusRequest) { + boolean statusRequest, boolean oidcRequest) { // first validate if we're required process this over TLS only @@ -5014,7 +5041,7 @@ void validateRequest(HttpServletRequest request, final String principalDomain, f ZTSConsts.ZTS_UNKNOWN_DOMAIN, principalDomain); } - // second check if this is a status port so we can only + // second check if this is a status port, so we can only // process on status requests if (statusPort > 0 && statusPort != httpPort && statusPort != httpsPort) { @@ -5033,6 +5060,18 @@ void validateRequest(HttpServletRequest request, final String principalDomain, f caller, ZTSConsts.ZTS_UNKNOWN_DOMAIN, principalDomain); } } + + // final check is for oidc requests + + if (oidcPort > 0 && oidcPort != httpsPort) { + + // non oidc requests must not take place on the oidc port + + if (!oidcRequest && request.getLocalPort() == oidcPort) { + throw requestError("incorrect port number for a non-oidc request", + caller, ZTSConsts.ZTS_UNKNOWN_DOMAIN, principalDomain); + } + } } void validate(Object val, final String type, final String principalDomain, final String caller) { diff --git a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSResources.java b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSResources.java index c73fffee4d6..286090f96a5 100644 --- a/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSResources.java +++ b/servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSResources.java @@ -935,13 +935,14 @@ public Response getOIDCResponse( @Parameter(description = "optional signing key type - RSA or EC. Might be ignored if server doesn't have the requested type configured", required = false) @QueryParam("keyType") String keyType, @Parameter(description = "flag to indicate to use full arn in group claim (e.g. sports:role.deployer instead of deployer)", required = false) @QueryParam("fullArn") @DefaultValue("false") Boolean fullArn, @Parameter(description = "optional expiry period specified in seconds", required = false) @QueryParam("expiryTime") Integer expiryTime, - @Parameter(description = "optional output format of json", required = false) @QueryParam("output") String output) { + @Parameter(description = "optional output format of json", required = false) @QueryParam("output") String output, + @Parameter(description = "flag to indicate to include role name in the audience claim only if we have a single role in response", required = false) @QueryParam("roleInAudClaim") @DefaultValue("false") Boolean roleInAudClaim) { int code = ResourceException.OK; ResourceContext context = null; try { context = this.delegate.newResourceContext(this.servletContext, this.request, this.response, "getOIDCResponse"); context.authenticate(); - return this.delegate.getOIDCResponse(context, responseType, clientId, redirectUri, scope, state, nonce, keyType, fullArn, expiryTime, output); + return this.delegate.getOIDCResponse(context, responseType, clientId, redirectUri, scope, state, nonce, keyType, fullArn, expiryTime, output, roleInAudClaim); } catch (ResourceException e) { code = e.getCode(); switch (code) { diff --git a/servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java b/servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java index 150b7ef53db..6542560184d 100644 --- a/servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java +++ b/servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java @@ -8276,15 +8276,15 @@ public void testValidateRequestSecureRequests() { // if secure requests is false, no check is done ztsImpl.validateRequest(request, "principal-domain", "test"); - ztsImpl.validateRequest(request, "principal-domain", "test", false); - ztsImpl.validateRequest(request, "principal-domain", "test", true); + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); + ztsImpl.validateRequest(request, "principal-domain", "test", true, false); // should complete successfully since our request is true ztsImpl.secureRequestsOnly = true; ztsImpl.validateRequest(request, "principal-domain", "test"); - ztsImpl.validateRequest(request, "principal-domain", "test", false); - ztsImpl.validateRequest(request, "principal-domain", "test", true); + ztsImpl.validateRequest(request, "principal-domain", "test", false, true); + ztsImpl.validateRequest(request, "principal-domain", "test", true, true); } @Test @@ -8309,12 +8309,12 @@ public void testValidateRequestNonSecureRequests() { } catch (ResourceException ignored) { } try { - ztsImpl.validateRequest(request, "principal-domain", "test", false); + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); fail(); } catch (ResourceException ignored) { } try { - ztsImpl.validateRequest(request, "principal-domain", "test", true); + ztsImpl.validateRequest(request, "principal-domain", "test", true, true); fail(); } catch (ResourceException ignored) { } @@ -8330,21 +8330,44 @@ public void testValidateRequestStatusRequestPort() { ZTSImpl ztsImpl = new ZTSImpl(mockCloudStore, store); ztsImpl.secureRequestsOnly = true; - ztsImpl.statusPort = 8443; HttpServletRequest request = Mockito.mock(HttpServletRequest.class); Mockito.when(request.isSecure()).thenReturn(true); Mockito.when(request.getLocalPort()).thenReturn(4443); - // non-status requests are allowed on port 4443 + // with status port 0, all requests are ok + + ztsImpl.statusPort = 0; + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); + ztsImpl.validateRequest(request, "principal-domain", "test", true, false); + + // with status set to equal to http port - all requests are ok + + ztsImpl.statusPort = 4080; + ztsImpl.httpPort = 4080; + + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); + ztsImpl.validateRequest(request, "principal-domain", "test", true, false); + + // with status set to equal to https port - all requests are ok + + ztsImpl.statusPort = 4443; + ztsImpl.httpsPort = 4443; + + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); + ztsImpl.validateRequest(request, "principal-domain", "test", true, false); + + // non-status requests are allowed on port 4443 with different status port + + ztsImpl.statusPort = 8443; ztsImpl.validateRequest(request, "principal-domain", "test"); - ztsImpl.validateRequest(request, "principal-domain", "test", false); + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); // status requests are not allowed on port 4443 try { - ztsImpl.validateRequest(request, "principal-domain", "test", true); + ztsImpl.validateStatusRequest(request, "principal-domain", "test"); fail(); } catch (ResourceException ignored) { } @@ -8361,6 +8384,7 @@ public void testValidateRequestRegularRequestPort() { ZTSImpl ztsImpl = new ZTSImpl(mockCloudStore, store); ztsImpl.secureRequestsOnly = true; ztsImpl.statusPort = 8443; + ztsImpl.oidcPort = 443; HttpServletRequest request = Mockito.mock(HttpServletRequest.class); Mockito.when(request.isSecure()).thenReturn(true); @@ -8368,7 +8392,7 @@ public void testValidateRequestRegularRequestPort() { // status requests are allowed on port 8443 - ztsImpl.validateRequest(request, "test", "principal-domain", true); + ztsImpl.validateStatusRequest(request, "test", "principal-domain"); // non-status requests are not allowed on port 8443 @@ -8379,7 +8403,64 @@ public void testValidateRequestRegularRequestPort() { } try { - ztsImpl.validateRequest(request, "principal-domain", "test", false); + ztsImpl.validateOIDCRequest(request, "principal-domain", "test"); + fail(); + } catch (ResourceException ignored) { + } + + try { + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); + fail(); + } catch (ResourceException ignored) { + } + } + + @Test + public void testValidateRequestOIDCRequestPort() { + + ChangeLogStore structStore = new ZMSFileChangeLogStore("/tmp/zts_server_unit_tests/zts_root", + privateKey, "0"); + + DataStore store = new DataStore(structStore, null, ztsMetric); + + ZTSImpl ztsImpl = new ZTSImpl(mockCloudStore, store); + ztsImpl.secureRequestsOnly = true; + + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + Mockito.when(request.isSecure()).thenReturn(true); + Mockito.when(request.getLocalPort()).thenReturn(443); + + // with oidc port 0, all requests are ok + + ztsImpl.oidcPort = 0; + ztsImpl.validateOIDCRequest(request, "test", "principal-domain"); + ztsImpl.validateRequest(request, "principal-domain", "test", false, true); + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); + + // with status set to equal to https port - all requests are ok + + ztsImpl.oidcPort = 4443; + ztsImpl.httpsPort = 4443; + + ztsImpl.validateOIDCRequest(request, "test", "principal-domain"); + ztsImpl.validateRequest(request, "principal-domain", "test", false, true); + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); + + // oidc requests are allowed on port 443 + + ztsImpl.oidcPort = 443; + ztsImpl.validateOIDCRequest(request, "test", "principal-domain"); + + // non-oidc requests are not allowed on port 443 + + try { + ztsImpl.validateRequest(request, "principal-domain", "test"); + fail(); + } catch (ResourceException ignored) { + } + + try { + ztsImpl.validateRequest(request, "principal-domain", "test", false, false); fail(); } catch (ResourceException ignored) { } @@ -13220,7 +13301,12 @@ public void testChangeMessage() { @Test public void testGetOpenIDConfig() { - ResourceContext ctx = createResourceContext(null); + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + Mockito.when(request.isSecure()).thenReturn(true); + Mockito.when(request.getLocalPort()).thenReturn(4443); + + RsrcCtxWrapper ctx = Mockito.mock(RsrcCtxWrapper.class); + Mockito.when(ctx.request()).thenReturn(request); OpenIDConfig openIDConfig = zts.getOpenIDConfig(ctx); assertNotNull(openIDConfig); @@ -13234,6 +13320,40 @@ public void testGetOpenIDConfig() { assertEquals(Collections.singletonList("public"), openIDConfig.getSubject_types_supported()); } + @Test + public void testGetOpendIDConfigOnOIDCPort() { + + ChangeLogStore structStore = new ZMSFileChangeLogStore("/tmp/zts_server_unit_tests/zts_root", + privateKey, "0"); + + DataStore store = new DataStore(structStore, null, ztsMetric); + System.setProperty(ZTSConsts.ZTS_PROP_OIDC_PORT_ISSUER, "https://athenz.io/zts/v1"); + System.setProperty(ZTSConsts.ZTS_PROP_OIDC_PORT, "443"); + + ZTSImpl ztsImpl = new ZTSImpl(mockCloudStore, store); + + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + Mockito.when(request.isSecure()).thenReturn(true); + Mockito.when(request.getLocalPort()).thenReturn(443); + + RsrcCtxWrapper ctx = Mockito.mock(RsrcCtxWrapper.class); + Mockito.when(ctx.request()).thenReturn(request); + + OpenIDConfig openIDConfig = ztsImpl.getOpenIDConfig(ctx); + assertNotNull(openIDConfig); + + assertEquals("https://athenz.io/zts/v1", openIDConfig.getIssuer()); + assertEquals("https://athenz.io/zts/v1/oauth2/keys?rfc=true", openIDConfig.getJwks_uri()); + assertEquals("https://athenz.io/zts/v1/oauth2/auth", openIDConfig.getAuthorization_endpoint()); + + assertEquals(Collections.singletonList("RS256"), openIDConfig.getId_token_signing_alg_values_supported()); + assertEquals(Collections.singletonList("id_token"), openIDConfig.getResponse_types_supported()); + assertEquals(Collections.singletonList("public"), openIDConfig.getSubject_types_supported()); + + System.clearProperty(ZTSConsts.ZTS_PROP_OIDC_PORT_ISSUER); + System.clearProperty(ZTSConsts.ZTS_PROP_OIDC_PORT); + } + @Test public void testGetOAuthConfig() { @@ -13330,7 +13450,7 @@ public void testGetOIDCResponseFailures() { try { zts.getOIDCResponse(context, "id_token", "coretech", "https://localhost:4443", "openid", - null, "nonce", "RSA", null, null, null); + null, "nonce", "RSA", null, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.BAD_REQUEST); @@ -13341,7 +13461,7 @@ public void testGetOIDCResponseFailures() { try { zts.getOIDCResponse(context, "id_token", "unknown-domain.api", "https://localhost:4443", - "openid", null, "nonce", "EC", null, null, null); + "openid", null, "nonce", "EC", null, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.NOT_FOUND); @@ -13357,7 +13477,7 @@ public void testGetOIDCResponseFailures() { try { zts.getOIDCResponse(context, "id_token", "coretech.backup", "https://localhost:4443/zts", - "openid", null, "nonce", "RSA", null, null, null); + "openid", null, "nonce", "RSA", null, null, null, Boolean.FALSE); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.BAD_REQUEST); @@ -13367,7 +13487,7 @@ public void testGetOIDCResponseFailures() { try { zts.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443", - "openid", "state", "nonce", null, null, null, null); + "openid", "state", "nonce", null, null, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.BAD_REQUEST); @@ -13377,7 +13497,7 @@ public void testGetOIDCResponseFailures() { try { zts.getOIDCResponse(context, "token", "coretech.api", "https://localhost:4443/zts", - "openid", null, "nonce", "", null, null, null); + "openid", null, "nonce", "", null, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.BAD_REQUEST); @@ -13388,7 +13508,7 @@ public void testGetOIDCResponseFailures() { try { zts.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "", null, "nonce", "rsa", null, null, null); + "", null, "nonce", "rsa", null, null, null, Boolean.TRUE); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.BAD_REQUEST); @@ -13397,7 +13517,7 @@ public void testGetOIDCResponseFailures() { try { zts.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - null, null, "nonce", "unknown", Boolean.FALSE, null, null); + null, null, "nonce", "unknown", Boolean.FALSE, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.BAD_REQUEST); @@ -13413,23 +13533,27 @@ public void testGetOIDCResponseNoRulesGroups() { CloudStore cloudStore = new CloudStore(); cloudStore.setHttpClient(null); ZTSImpl ztsImpl = new ZTSImpl(cloudStore, store); + ztsImpl.oidcPort = 443; + ztsImpl.ztsOIDCPortIssuer = "https://athenz.io"; + // set back to our zts rsa private key System.setProperty(FilePrivateKeyStore.ATHENZ_PROP_PRIVATE_KEY, "src/test/resources/unit_test_zts_private.pem"); Principal principal = SimplePrincipal.create("user_domain", "user", "v=U1;d=user_domain;n=user;s=signature", 0, null); ResourceContext context = createResourceContext(principal); + Mockito.when(context.request().getLocalPort()).thenReturn(443); SignedDomain signedDomain = createSignedDomain("coretech", "sports", "api", true); store.processSignedDomain(signedDomain, false); Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid", null, "nonce", "RSA", Boolean.FALSE, null, null); + "openid", null, "nonce", "RSA", Boolean.FALSE, null, null, Boolean.TRUE); Jws claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), null); assertNotNull(claims); assertEquals("user_domain.user", claims.getBody().getSubject()); assertEquals("coretech.api", claims.getBody().getAudience()); - assertEquals(ztsImpl.ztsOpenIDIssuer, claims.getBody().getIssuer()); + assertEquals("https://athenz.io", claims.getBody().getIssuer()); List groups = (List) claims.getBody().get("groups"); assertNull(groups); } @@ -13462,7 +13586,7 @@ public void testGetOIDCResponseGroups() { // get all the groups Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid groups", null, "nonce", "EC", null, null, null); + "openid groups", null, "nonce", "EC", null, null, null, null); Jws claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), null); assertNotNull(claims); assertEquals("user_domain.user", claims.getBody().getSubject()); @@ -13478,7 +13602,7 @@ public void testGetOIDCResponseGroups() { // get only one of the groups and include state response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:group.dev-team", "valid-state", "nonce", "RSA", null, null, null); + "openid coretech:group.dev-team", "valid-state", "nonce", "RSA", null, null, null, null); assertEquals(response.getStatus(), ResourceException.FOUND); String location = response.getHeaderString("Location"); final String stateComp = "&state=valid-state"; @@ -13506,7 +13630,7 @@ public void testGetOIDCResponseGroups() { try { ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:group.eng-team", null, "nonce", null, Boolean.FALSE, null, null); + "openid coretech:group.eng-team", null, "nonce", null, Boolean.FALSE, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.FORBIDDEN); @@ -13545,7 +13669,7 @@ public void testGetOIDCResponseGroupsDifferentDomain() { // get all the groups from the coretech domain Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid groups weather:domain", null, "nonce", "EC", null, null, null); + "openid groups weather:domain", null, "nonce", "EC", null, null, null, null); Jws claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), null); assertNotNull(claims); assertEquals("user_domain.user", claims.getBody().getSubject()); @@ -13562,7 +13686,7 @@ public void testGetOIDCResponseGroupsDifferentDomain() { try { ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid groups unknown-domain:domain", null, "nonce", "EC", null, null, null); + "openid groups unknown-domain:domain", null, "nonce", "EC", null, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.NOT_FOUND); @@ -13632,8 +13756,9 @@ public void testGetOIDCResponseGroupsMultipleDomains() { // get all the groups - Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid groups coretech:domain weather:domain homepage:domain", null, "nonce", "EC", null, null, null); + Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", + "https://localhost:4443/zts", "openid groups coretech:domain weather:domain homepage:domain", + null, "nonce", "EC", null, null, null, null); assertEquals(response.getStatus(), ResourceException.FOUND); String location = response.getHeaderString("Location"); @@ -13662,8 +13787,9 @@ public void testGetOIDCResponseGroupsMultipleDomains() { // get only one of the groups and include state - response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:group.dev-team weather:group.pe-team", "valid-state", "nonce", "RSA", null, null, null); + response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", + "https://localhost:4443/zts", "openid coretech:group.dev-team weather:group.pe-team", + "valid-state", "nonce", "RSA", null, null, null, Boolean.FALSE); assertEquals(response.getStatus(), ResourceException.FOUND); location = response.getHeaderString("Location"); String stateComp = "&state=valid-state"; @@ -13692,7 +13818,8 @@ public void testGetOIDCResponseGroupsMultipleDomains() { try { ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:group.eng-team weather:group.eng-team", null, "nonce", null, Boolean.FALSE, null, null); + "openid coretech:group.eng-team weather:group.eng-team", null, "nonce", null, + Boolean.FALSE, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.FORBIDDEN); @@ -13703,7 +13830,8 @@ public void testGetOIDCResponseGroupsMultipleDomains() { try { ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:group.eng finance:group.eng", null, "nonce", "EC", Boolean.FALSE, null, null); + "openid coretech:group.eng finance:group.eng", null, "nonce", "EC", Boolean.FALSE, + null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.NOT_FOUND); @@ -13713,7 +13841,8 @@ public void testGetOIDCResponseGroupsMultipleDomains() { // requests from domains where the user is not part of any groups response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid groups homepage:domain fantasy:domain", "valid-state", "nonce", "RSA", null, null, null); + "openid groups homepage:domain fantasy:domain", "valid-state", "nonce", "RSA", null, + null, null, null); assertEquals(response.getStatus(), ResourceException.FOUND); location = response.getHeaderString("Location"); stateComp = "&state=valid-state"; @@ -13738,15 +13867,19 @@ public void testGetOIDCResponseGroupsMultipleDomains() { @Test public void testGetOIDCResponseRolesWithJson() { - testGetOIDCResponseRoles("json"); + testGetOIDCResponseRoles("json", null); + testGetOIDCResponseRoles("json", Boolean.FALSE); + testGetOIDCResponseRoles("json", Boolean.TRUE); } @Test public void testGetOIDCResponseRolesRFC() { - testGetOIDCResponseRoles(null); + testGetOIDCResponseRoles(null, null); + testGetOIDCResponseRoles(null, Boolean.FALSE); + testGetOIDCResponseRoles(null, Boolean.TRUE); } - private void testGetOIDCResponseRoles(final String output) { + private void testGetOIDCResponseRoles(final String output, Boolean roleInAudClaim) { System.setProperty(FilePrivateKeyStore.ATHENZ_PROP_PRIVATE_KEY, "src/test/resources/unit_test_zts_at_private.pem"); @@ -13768,11 +13901,11 @@ private void testGetOIDCResponseRoles(final String output) { // get all the roles Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid roles", null, "nonce", "", null, null, output); + "openid roles", null, "nonce", "", null, null, output, roleInAudClaim); Jws claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), output); assertNotNull(claims); assertEquals("user_domain.user", claims.getBody().getSubject()); - assertEquals("coretech.api", claims.getBody().getAudience()); + assertEquals((roleInAudClaim == Boolean.TRUE) ? "coretech.api:writers" : "coretech.api", claims.getBody().getAudience()); assertEquals("nonce", claims.getBody().get("nonce", String.class)); assertEquals(ztsImpl.ztsOpenIDIssuer, claims.getBody().getIssuer()); List userRoles = (List) claims.getBody().get("groups"); @@ -13784,11 +13917,11 @@ private void testGetOIDCResponseRoles(final String output) { // which should be honored response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:role.writers", null, "nonce", "RSA", Boolean.FALSE, 30 * 60, output); + "openid coretech:role.writers", null, "nonce", "RSA", Boolean.FALSE, 30 * 60, output, roleInAudClaim); claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), output); assertNotNull(claims); assertEquals("user_domain.user", claims.getBody().getSubject()); - assertEquals("coretech.api", claims.getBody().getAudience()); + assertEquals((roleInAudClaim == Boolean.TRUE) ? "coretech.api:writers" : "coretech.api", claims.getBody().getAudience()); assertEquals(ztsImpl.ztsOpenIDIssuer, claims.getBody().getIssuer()); userRoles = (List) claims.getBody().get("groups"); assertNotNull(userRoles); @@ -13800,7 +13933,7 @@ private void testGetOIDCResponseRoles(final String output) { // expiry is still set to 1 hour response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:role.writers", null, "nonce", "RSA", Boolean.FALSE, 120 * 60, output); + "openid coretech:role.writers", null, "nonce", "RSA", Boolean.FALSE, 120 * 60, output, roleInAudClaim); claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), output); assertNotNull(claims); assertEquals(claims.getBody().getExpiration().getTime() - claims.getBody().getIssuedAt().getTime(), 60 * 60 * 1000); @@ -13812,7 +13945,7 @@ private void testGetOIDCResponseRoles(final String output) { ztsImpl.userDomain = "user-other-domain"; response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:role.writers", null, "nonce", "RSA", Boolean.FALSE, 120 * 60, output); + "openid coretech:role.writers", null, "nonce", "RSA", Boolean.FALSE, 120 * 60, output, roleInAudClaim); claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), output); assertNotNull(claims); assertEquals(claims.getBody().getExpiration().getTime() - claims.getBody().getIssuedAt().getTime(), 120 * 60 * 1000); @@ -13825,7 +13958,7 @@ private void testGetOIDCResponseRoles(final String output) { try { ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:role.eng-team", null, "nonce", "EC", Boolean.FALSE, null, output); + "openid coretech:role.eng-team", null, "nonce", "EC", Boolean.FALSE, null, output, roleInAudClaim); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.FORBIDDEN); @@ -13880,7 +14013,7 @@ public void testGetOIDCResponseRolesDifferentDomain() { // get all the roles Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid roles weather:domain", null, "nonce", "", null, null, null); + "openid roles weather:domain", null, "nonce", "", null, null, null, Boolean.FALSE); Jws claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), null); assertNotNull(claims); @@ -13897,7 +14030,7 @@ public void testGetOIDCResponseRolesDifferentDomain() { try { ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid roles unknown-domain:domain", null, "nonce", "EC", null, null, null); + "openid roles unknown-domain:domain", null, "nonce", "EC", null, null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.NOT_FOUND); @@ -13935,8 +14068,9 @@ public void testGetOIDCResponseRolesMultipleDomains() { // get all the roles - Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid roles coretech:domain weather:domain homepage:domain", null, "nonce", "", null, null, null); + Response response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", + "https://localhost:4443/zts", "openid roles coretech:domain weather:domain homepage:domain", + null, "nonce", "", null, null, null, null); Jws claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), null); assertNotNull(claims); @@ -13953,7 +14087,8 @@ public void testGetOIDCResponseRolesMultipleDomains() { // specific the roles explicitly response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:role.writers weather:role.writers", null, "nonce", "RSA", Boolean.FALSE, null, null); + "openid coretech:role.writers weather:role.writers", null, "nonce", "RSA", Boolean.FALSE, + null, null, null); assertEquals(response.getStatus(), ResourceException.FOUND); claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), null); @@ -13971,7 +14106,8 @@ public void testGetOIDCResponseRolesMultipleDomains() { try { ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:role.eng weather:role.eng", null, "nonce", "EC", Boolean.FALSE, null, null); + "openid coretech:role.eng weather:role.eng", null, "nonce", "EC", Boolean.FALSE, + null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.FORBIDDEN); @@ -13982,7 +14118,8 @@ public void testGetOIDCResponseRolesMultipleDomains() { try { ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid coretech:role.eng finance:role.eng", null, "nonce", "EC", Boolean.FALSE, null, null); + "openid coretech:role.eng finance:role.eng", null, "nonce", "EC", Boolean.FALSE, + null, null, null); fail(); } catch (ResourceException ex) { assertEquals(ex.getCode(), ResourceException.NOT_FOUND); @@ -13992,7 +14129,7 @@ public void testGetOIDCResponseRolesMultipleDomains() { // requests from domains where the user is not part of any role response = ztsImpl.getOIDCResponse(context, "id_token", "coretech.api", "https://localhost:4443/zts", - "openid roles homepage:domain fantasy:domain", null, "nonce", "RSA", null, null, null); + "openid roles homepage:domain fantasy:domain", null, "nonce", "RSA", null, null, null, null); claims = getClaimsFromResponse(response, ztsImpl.privateKey.getKey(), null); assertNotNull(claims); @@ -14209,4 +14346,54 @@ public void testGetIdTokenGroupsFromRoles() { assertTrue(resGroups.contains("coretech:role.reader")); assertTrue(resGroups.contains("coretech:role.writer")); } + + @Test + public void testIsOidcPortRequest() { + + ChangeLogStore structStore = new ZMSFileChangeLogStore("/tmp/zts_server_unit_tests/zts_root", + privateKey, "0"); + + DataStore store = new DataStore(structStore, null, ztsMetric); + + ZTSImpl ztsImpl = new ZTSImpl(mockCloudStore, store); + ztsImpl.oidcPort = 0; + ztsImpl.httpsPort = 4443; + + assertFalse(ztsImpl.isOidcPortRequest(443)); + assertFalse(ztsImpl.isOidcPortRequest(4443)); + + ztsImpl.oidcPort = 443; + ztsImpl.httpsPort = 4443; + + assertTrue(ztsImpl.isOidcPortRequest(443)); + assertFalse(ztsImpl.isOidcPortRequest(4443)); + + ztsImpl.oidcPort = 4443; + ztsImpl.httpsPort = 4443; + + assertFalse(ztsImpl.isOidcPortRequest(443)); + assertFalse(ztsImpl.isOidcPortRequest(4443)); + } + + @Test + public void testGetIdTokenAudience() { + assertEquals(zts.getIdTokenAudience("id", null, null), "id"); + assertEquals(zts.getIdTokenAudience("id", Boolean.FALSE, null), "id"); + assertEquals(zts.getIdTokenAudience("id", Boolean.TRUE, null), "id"); + + List idTokenGroups = new ArrayList<>(); + assertEquals(zts.getIdTokenAudience("id", null, idTokenGroups), "id"); + assertEquals(zts.getIdTokenAudience("id", Boolean.FALSE, idTokenGroups), "id"); + assertEquals(zts.getIdTokenAudience("id", Boolean.TRUE, idTokenGroups), "id"); + + idTokenGroups.add("athenz:role.oidc"); + assertEquals(zts.getIdTokenAudience("id", null, idTokenGroups), "id"); + assertEquals(zts.getIdTokenAudience("id", Boolean.FALSE, idTokenGroups), "id"); + assertEquals(zts.getIdTokenAudience("id", Boolean.TRUE, idTokenGroups), "id:athenz:role.oidc"); + + idTokenGroups.add("athenz:role.oidc2"); + assertEquals(zts.getIdTokenAudience("id", null, idTokenGroups), "id"); + assertEquals(zts.getIdTokenAudience("id", Boolean.FALSE, idTokenGroups), "id"); + assertEquals(zts.getIdTokenAudience("id", Boolean.TRUE, idTokenGroups), "id"); + } } diff --git a/utils/zts-idtoken/zts-idtoken.go b/utils/zts-idtoken/zts-idtoken.go index 0ce38c0631d..f60f4fa382c 100644 --- a/utils/zts-idtoken/zts-idtoken.go +++ b/utils/zts-idtoken/zts-idtoken.go @@ -23,7 +23,7 @@ var ( ) func usage() { - fmt.Println("usage: zts-idtoken -zts -scope -redirect-uri -nonce -client-id -state -key-type -format ") + fmt.Println("usage: zts-idtoken -zts -scope -redirect-uri -nonce -client-id -state -key-type -format [-full-arn=true] [-role-in-aud-claim=true]") fmt.Println(" := -svc-key-file -svc-cert-file [-svc-cacert-file ]") fmt.Println(" zts-idtoken -validate -id-token -conf [-claims]") os.Exit(1) @@ -39,7 +39,7 @@ func printVersion() { func main() { var clientId, scope, state, redirectUri, nonce, svcKeyFile, svcCertFile, svcCACertFile, ztsURL, conf, idToken, keyType, format string - var proxy, validate, claims, showVersion, fullArn bool + var proxy, validate, claims, showVersion, fullArn, roleInAudClaim bool var expireTime int flag.StringVar(&clientId, "client-id", "", "client-id for the token") flag.StringVar(&redirectUri, "redirect-uri", "", "redirect uri registered for the client-id") @@ -60,6 +60,7 @@ func main() { flag.BoolVar(&showVersion, "version", false, "Show version") flag.StringVar(&format, "format", "token", "Output format: token | kubectl") flag.IntVar(&expireTime, "expire-time", 60, "token expire time in minutes") + flag.BoolVar(&roleInAudClaim, "role-in-aud-claim", false, "include role name in aud claim") flag.Parse() if showVersion { @@ -70,7 +71,7 @@ func main() { if validate { validateIdToken(idToken, conf, claims) } else { - fetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redirectUri, scope, nonce, state, keyType, format, &fullArn, proxy, expireTime) + fetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redirectUri, scope, nonce, state, keyType, format, &fullArn, proxy, expireTime, &roleInAudClaim) } } @@ -115,7 +116,7 @@ func validateIdToken(idToken, conf string, showClaims bool) { fmt.Println("Id Token successfully validated") } -func fetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redirectUri, scope, nonce, state, keyType, format string, fullArn *bool, proxy bool, expireTime int) { +func fetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redirectUri, scope, nonce, state, keyType, format string, fullArn *bool, proxy bool, expireTime int, roleInAudClaim *bool) { defaultConfig, _ := athenzutils.ReadDefaultConfig() // check to see if we need to use zts url from our default config file @@ -129,7 +130,7 @@ func fetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redi // need to convert minutes into seconds expireTimeSecs := int32(expireTime) * 60 - idToken, err := athenzutils.FetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redirectUri, scope, nonce, state, keyType, fullArn, proxy, &expireTimeSecs) + idToken, err := athenzutils.FetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, redirectUri, scope, nonce, state, keyType, fullArn, proxy, &expireTimeSecs, roleInAudClaim) if err != nil { log.Fatalf("unable to fetch id token: %v\n", err) }