Skip to content

Commit

Permalink
[java] Allowing setting SSL context in client config for HttpClient (#…
Browse files Browse the repository at this point in the history
…12874)

* [java] Allowing setting SSL context in client config for HttpClient

Fixes #12869

* [java] Fix refactor error

* fix linter complaints

---------

Co-authored-by: titusfortner <[email protected]>
Co-authored-by: Titus Fortner <[email protected]>
  • Loading branch information
3 people authored Oct 6, 2023
1 parent 33c4122 commit b9bdff1
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 11 deletions.
52 changes: 43 additions & 9 deletions java/src/org/openqa/selenium/remote/http/ClientConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.net.URISyntaxException;
import java.net.URL;
import java.time.Duration;
import javax.net.ssl.SSLContext;
import org.openqa.selenium.Credentials;
import org.openqa.selenium.internal.Require;

Expand All @@ -38,24 +39,28 @@ public class ClientConfig {
private final Proxy proxy;
private final Credentials credentials;

private final SSLContext sslContext;

protected ClientConfig(
URI baseUri,
Duration connectionTimeout,
Duration readTimeout,
Filter filters,
Proxy proxy,
Credentials credentials) {
Credentials credentials,
SSLContext sslContext) {
this.baseUri = baseUri;
this.connectionTimeout = Require.nonNegative("Connection timeout", connectionTimeout);
this.readTimeout = Require.nonNegative("Read timeout", readTimeout);
this.filters = Require.nonNull("Filters", filters);
this.proxy = proxy;
this.credentials = credentials;
this.sslContext = sslContext;
}

public static ClientConfig defaultConfig() {
return new ClientConfig(
null, Duration.ofSeconds(10), Duration.ofMinutes(3), DEFAULT_FILTER, null, null);
null, Duration.ofSeconds(10), Duration.ofMinutes(3), DEFAULT_FILTER, null, null, null);
}

public ClientConfig baseUri(URI baseUri) {
Expand All @@ -65,7 +70,8 @@ public ClientConfig baseUri(URI baseUri) {
readTimeout,
filters,
proxy,
credentials);
credentials,
sslContext);
}

public ClientConfig baseUrl(URL baseUrl) {
Expand Down Expand Up @@ -95,7 +101,8 @@ public ClientConfig connectionTimeout(Duration timeout) {
readTimeout,
filters,
proxy,
credentials);
credentials,
sslContext);
}

public Duration connectionTimeout() {
Expand All @@ -109,7 +116,8 @@ public ClientConfig readTimeout(Duration timeout) {
Require.nonNull("Read timeout", timeout),
filters,
proxy,
credentials);
credentials,
sslContext);
}

public Duration readTimeout() {
Expand All @@ -124,12 +132,19 @@ public ClientConfig withFilter(Filter filter) {
readTimeout,
filter.andThen(DEFAULT_FILTER),
proxy,
credentials);
credentials,
sslContext);
}

public ClientConfig withRetries() {
return new ClientConfig(
baseUri, connectionTimeout, readTimeout, filters.andThen(RETRY_FILTER), proxy, credentials);
baseUri,
connectionTimeout,
readTimeout,
filters.andThen(RETRY_FILTER),
proxy,
credentials,
sslContext);
}

public Filter filter() {
Expand All @@ -143,7 +158,8 @@ public ClientConfig proxy(Proxy proxy) {
readTimeout,
filters,
Require.nonNull("Proxy", proxy),
credentials);
credentials,
sslContext);
}

public Proxy proxy() {
Expand All @@ -157,13 +173,29 @@ public ClientConfig authenticateAs(Credentials credentials) {
readTimeout,
filters,
proxy,
Require.nonNull("Credentials", credentials));
Require.nonNull("Credentials", credentials),
sslContext);
}

public Credentials credentials() {
return credentials;
}

public ClientConfig sslContext(SSLContext sslContext) {
return new ClientConfig(
baseUri,
connectionTimeout,
readTimeout,
filters,
proxy,
credentials,
Require.nonNull("SSL Context", sslContext));
}

public SSLContext sslContext() {
return sslContext;
}

@Override
public String toString() {
return "ClientConfig{"
Expand All @@ -179,6 +211,8 @@ public String toString() {
+ proxy
+ ", credentials="
+ credentials
+ ", sslcontext="
+ sslContext
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLContext;
import org.openqa.selenium.Credentials;
import org.openqa.selenium.TimeoutException;
import org.openqa.selenium.UsernameAndPassword;
Expand Down Expand Up @@ -144,6 +145,11 @@ public void connectFailed(URI uri, SocketAddress sa, IOException ioe) {
builder = builder.proxy(proxySelector);
}

SSLContext sslContext = config.sslContext();
if (sslContext != null) {
builder.sslContext(sslContext);
}

this.client = builder.build();
}

Expand Down
52 changes: 50 additions & 2 deletions java/test/org/openqa/selenium/grid/router/ProxyWebsocketTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.net.Socket;
import java.net.URISyntaxException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.util.Collections;
import java.util.Optional;
Expand All @@ -32,6 +37,10 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedTrustManager;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -49,6 +58,7 @@
import org.openqa.selenium.grid.sessionmap.local.LocalSessionMap;
import org.openqa.selenium.netty.server.NettyServer;
import org.openqa.selenium.remote.SessionId;
import org.openqa.selenium.remote.http.ClientConfig;
import org.openqa.selenium.remote.http.HttpClient;
import org.openqa.selenium.remote.http.HttpHandler;
import org.openqa.selenium.remote.http.HttpRequest;
Expand Down Expand Up @@ -181,7 +191,10 @@ public void onText(CharSequence data) {
@ParameterizedTest
@MethodSource("data")
void shouldBeAbleToSendMessagesOverSecureWebSocket(Supplier<String> values)
throws URISyntaxException, InterruptedException {
throws URISyntaxException,
InterruptedException,
NoSuchAlgorithmException,
KeyManagementException {
setFields(values);
Config secureConfig =
new MapConfig(ImmutableMap.of("server", ImmutableMap.of("https-self-signed", true)));
Expand All @@ -207,11 +220,46 @@ void shouldBeAbleToSendMessagesOverSecureWebSocket(Supplier<String> values)
new ImmutableCapabilities(),
Instant.now()));

final TrustManager trustManager =
new X509ExtendedTrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) {}

@Override
public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) {}

@Override
public void checkClientTrusted(
X509Certificate[] chain, String authType, SSLEngine engine) {}

@Override
public void checkServerTrusted(
X509Certificate[] chain, String authType, SSLEngine engine) {}

@Override
public java.security.cert.X509Certificate[] getAcceptedIssuers() {
return new java.security.cert.X509Certificate[0];
}

@Override
public void checkClientTrusted(X509Certificate[] chain, String authType) {}

@Override
public void checkServerTrusted(
java.security.cert.X509Certificate[] chain, String authType) {}
};

SSLContext sslContext = SSLContext.getInstance("SSL");
sslContext.init(null, new TrustManager[] {trustManager}, new SecureRandom());

CountDownLatch latch = new CountDownLatch(1);
AtomicReference<String> text = new AtomicReference<>();
try (WebSocket socket =
clientFactory
.createClient(secureProxyServer.getUrl())
.createClient(
ClientConfig.defaultConfig()
.baseUrl(secureProxyServer.getUrl())
.sslContext(sslContext))
.openSocket(
new HttpRequest(GET, String.format("/session/%s/" + protocol, id)),
new WebSocket.Listener() {
Expand Down

0 comments on commit b9bdff1

Please sign in to comment.