Skip to content

Commit

Permalink
[grid] Allow Node implementation to be pluggable
Browse files Browse the repository at this point in the history
This uses the same mechanism as used by `EventBus` and
`SessionMap`. By default we use a local instance, but --- hey! ---
anything is possible.
  • Loading branch information
shs96c committed Jun 18, 2020
1 parent c3e1b9e commit 62d3333
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@
import org.openqa.selenium.grid.config.Role;
import org.openqa.selenium.grid.distributor.Distributor;
import org.openqa.selenium.grid.distributor.local.LocalDistributor;
import org.openqa.selenium.grid.docker.DockerOptions;
import org.openqa.selenium.grid.graphql.GraphqlHandler;
import org.openqa.selenium.grid.log.LoggingOptions;
import org.openqa.selenium.grid.node.Node;
import org.openqa.selenium.grid.node.ProxyNodeCdp;
import org.openqa.selenium.grid.node.config.NodeOptions;
import org.openqa.selenium.grid.node.local.LocalNode;
import org.openqa.selenium.grid.router.Router;
import org.openqa.selenium.grid.server.BaseServerOptions;
import org.openqa.selenium.grid.server.EventBusOptions;
Expand Down Expand Up @@ -148,18 +146,7 @@ protected void execute(Config config) {
Route.post("/graphql").to(() -> graphqlHandler),
Route.get("/readyz").to(() -> readinessCheck));

LocalNode.Builder nodeBuilder = LocalNode.builder(
tracer,
bus,
localhost,
localhost,
null)
.maximumConcurrentSessions(Runtime.getRuntime().availableProcessors() * 3);

new NodeOptions(config).configure(tracer, clientFactory, nodeBuilder);
new DockerOptions(config).configure(tracer, clientFactory, nodeBuilder);

Node node = nodeBuilder.build();
Node node = new NodeOptions(config).getNode();
combinedHandler.addHandler(node);
distributor.add(node);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ java_library(
visibility = [
"//java/server/src/org/openqa/selenium/grid/commands:__pkg__",
"//java/server/src/org/openqa/selenium/grid/node/httpd:__pkg__",
"//java/server/src/org/openqa/selenium/grid/node/local:__pkg__",
],
deps = [
"//java:auto-service",
Expand All @@ -16,7 +17,6 @@ java_library(
"//java/server/src/org/openqa/selenium/grid/config",
"//java/server/src/org/openqa/selenium/grid/data",
"//java/server/src/org/openqa/selenium/grid/node",
"//java/server/src/org/openqa/selenium/grid/node/local",
artifact("com.beust:jcommander"),
artifact("com.google.guava:guava"),
],
Expand Down
17 changes: 13 additions & 4 deletions java/server/src/org/openqa/selenium/grid/docker/DockerOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.openqa.selenium.grid.docker;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;
import org.openqa.selenium.Capabilities;
import org.openqa.selenium.Platform;
Expand All @@ -26,7 +28,7 @@
import org.openqa.selenium.docker.Image;
import org.openqa.selenium.grid.config.Config;
import org.openqa.selenium.grid.config.ConfigException;
import org.openqa.selenium.grid.node.local.LocalNode;
import org.openqa.selenium.grid.node.SessionFactory;
import org.openqa.selenium.internal.Require;
import org.openqa.selenium.json.Json;
import org.openqa.selenium.remote.http.ClientConfig;
Expand All @@ -36,7 +38,9 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -102,9 +106,12 @@ private boolean isEnabled(HttpClient.Factory clientFactory) {
return new Docker(client).isSupported();
}

public void configure(Tracer tracer, HttpClient.Factory clientFactory, LocalNode.Builder node) {
public Map<Capabilities, Collection<SessionFactory>> getDockerSessionFactories(
Tracer tracer,
HttpClient.Factory clientFactory) {

if (!isEnabled(clientFactory)) {
return;
return ImmutableMap.of();
}

List<String> allConfigs = config.getAll(DOCKER_SECTION, "configs")
Expand All @@ -128,17 +135,19 @@ public void configure(Tracer tracer, HttpClient.Factory clientFactory, LocalNode
loadImages(docker, kinds.keySet().toArray(new String[0]));

int maxContainerCount = Runtime.getRuntime().availableProcessors();
ImmutableMultimap.Builder<Capabilities, SessionFactory> factories = ImmutableMultimap.builder();
kinds.forEach((name, caps) -> {
Image image = docker.getImage(name);
for (int i = 0; i < maxContainerCount; i++) {
node.add(caps, new DockerSessionFactory(tracer, clientFactory, docker, image, caps));
factories.put(caps, new DockerSessionFactory(tracer, clientFactory, docker, image, caps));
}
LOG.info(String.format(
"Mapping %s to docker image %s %d times",
caps,
name,
maxContainerCount));
});
return factories.build().asMap();
}

private void loadImages(Docker docker, String... imageNames) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ java_library(
"//java/server/src/org/openqa/selenium/grid/config",
"//java/server/src/org/openqa/selenium/grid/data",
"//java/server/src/org/openqa/selenium/grid/node",
"//java/server/src/org/openqa/selenium/grid/node/local",
artifact("com.google.guava:guava"),
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,30 @@
package org.openqa.selenium.grid.node.config;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;
import org.openqa.selenium.Capabilities;
import org.openqa.selenium.WebDriverInfo;
import org.openqa.selenium.grid.config.Config;
import org.openqa.selenium.grid.config.ConfigException;
import org.openqa.selenium.grid.node.Node;
import org.openqa.selenium.grid.node.SessionFactory;
import org.openqa.selenium.grid.node.local.LocalNode;
import org.openqa.selenium.internal.Require;
import org.openqa.selenium.json.Json;
import org.openqa.selenium.json.JsonOutput;
import org.openqa.selenium.remote.http.HttpClient;
import org.openqa.selenium.remote.service.DriverService;
import org.openqa.selenium.remote.tracing.Tracer;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.ServiceLoader;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
Expand All @@ -47,6 +50,8 @@ public class NodeOptions {

private static final Logger LOG = Logger.getLogger(NodeOptions.class.getName());
private static final Json JSON = new Json();
private static final String DEFAULT_IMPL = "org.openqa.selenium.grid.node.local.LocalNodeFactory";

private final Config config;

public NodeOptions(Config config) {
Expand All @@ -63,80 +68,92 @@ public Optional<URI> getPublicGridUri() {
});
}

public void configure(Tracer tracer, HttpClient.Factory httpClientFactory, LocalNode.Builder node) {
public Node getNode() {
return config.getClass("node", "implementation", Node.class, DEFAULT_IMPL);
}

public Map<Capabilities, Collection<SessionFactory>> getSessionFactories(
/* Danger! Java stereotype ahead! */ Function<WebDriverInfo, Collection<SessionFactory>> factoryFactory) {

int maxSessions = Math.min(
config.getInt("node", "max-concurrent-sessions").orElse(Runtime.getRuntime().availableProcessors()),
Runtime.getRuntime().availableProcessors());

Map<WebDriverInfo, Collection<SessionFactory>> allDrivers = discoverDrivers(tracer, httpClientFactory, maxSessions);
Map<WebDriverInfo, Collection<SessionFactory>> allDrivers = discoverDrivers(maxSessions, factoryFactory);

// If drivers have been specified, use those.
List<String> drivers = config.getAll("node", "drivers").orElse(new ArrayList<>()).stream()
.map(String::toLowerCase)
.collect(Collectors.toList());

ImmutableMultimap.Builder<Capabilities, SessionFactory> sessionFactories = ImmutableMultimap.builder();

if (!drivers.isEmpty()) {
allDrivers.entrySet().stream()
.filter(entry -> drivers.contains(entry.getKey().getDisplayName().toLowerCase()))
.sorted(Comparator.comparing(entry -> entry.getKey().getDisplayName().toLowerCase()))
.peek(this::report)
.forEach(entry -> entry.getValue().forEach(factory -> node.add(entry.getKey().getCanonicalCapabilities(), factory)));
.forEach(entry -> sessionFactories.putAll(entry.getKey().getCanonicalCapabilities(), entry.getValue()));

return;
return sessionFactories.build().asMap();
}

if (!config.getBool("node", "detect-drivers").orElse(false)) {
return;
return sessionFactories.build().asMap();
}

allDrivers.entrySet().stream()
.peek(this::report)
.forEach(entry -> entry.getValue().forEach(factory -> node.add(entry.getKey().getCanonicalCapabilities(), factory)));
}
.forEach(entry -> sessionFactories.putAll(entry.getKey().getCanonicalCapabilities(), entry.getValue()));

private void report(Map.Entry<WebDriverInfo, Collection<SessionFactory>> entry) {
StringBuilder caps = new StringBuilder();
try (JsonOutput out = JSON.newOutput(caps)) {
out.setPrettyPrint(false);
out.write(entry.getKey().getCanonicalCapabilities());
}

LOG.info(String.format(
"Adding %s for %s %d times",
entry.getKey().getDisplayName(),
caps.toString().replaceAll("\\s+", " "),
entry.getValue().size()));
return sessionFactories.build().asMap();
}

private Map<WebDriverInfo, Collection<SessionFactory>> discoverDrivers(
Tracer tracer,
HttpClient.Factory clientFactory,
int maxSessions) {
int maxSessions,
Function<WebDriverInfo, Collection<SessionFactory>> factoryFactory) {

if (!config.getBool("node", "detect-drivers").orElse(false)) {
return ImmutableMap.of();
}

// We don't expect duplicates, but they're fine
List<WebDriverInfo> infos =
StreamSupport.stream(ServiceLoader.load(WebDriverInfo.class).spliterator(), false)
.filter(WebDriverInfo::isAvailable)
.sorted(Comparator.comparing(info -> info.getDisplayName().toLowerCase()))
.collect(Collectors.toList());

// Same
List<DriverService.Builder> builders = new ArrayList<>();
List<DriverService.Builder<?, ?>> builders = new ArrayList<>();
ServiceLoader.load(DriverService.Builder.class).forEach(builders::add);

HashMultimap<WebDriverInfo, SessionFactory> toReturn = HashMultimap.create();
Multimap<WebDriverInfo, SessionFactory> toReturn = HashMultimap.create();
infos.forEach(info -> {
Capabilities caps = info.getCanonicalCapabilities();
builders.stream()
.filter(builder -> builder.score(caps) > 0)
.forEach(builder -> {
for (int i = 0; i < Math.min(info.getMaximumSimultaneousSessions(), maxSessions); i++) {

DriverService.Builder freePortBuilder = builder.usingAnyFreePort();
toReturn.put(info, new DriverServiceSessionFactory(
tracer,
clientFactory, c -> freePortBuilder.score(c) > 0,
freePortBuilder));
toReturn.putAll(info, factoryFactory.apply(info));
}
});
});

return toReturn.asMap();
}

private void report(Map.Entry<WebDriverInfo, Collection<SessionFactory>> entry) {
StringBuilder caps = new StringBuilder();
try (JsonOutput out = JSON.newOutput(caps)) {
out.setPrettyPrint(false);
out.write(entry.getKey().getCanonicalCapabilities());
}

LOG.info(String.format(
"Adding %s for %s %d times",
entry.getKey().getDisplayName(),
caps.toString().replaceAll("\\s+", " "),
entry.getValue().size()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ java_library(
"//java/server/src/org/openqa/selenium/grid/log",
"//java/server/src/org/openqa/selenium/grid/node",
"//java/server/src/org/openqa/selenium/grid/node/config",
"//java/server/src/org/openqa/selenium/grid/node/local",
"//java/server/src/org/openqa/selenium/grid/server",
"//java/server/src/org/openqa/selenium/grid/web",
"//java/server/src/org/openqa/selenium/netty/server",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
import org.openqa.selenium.grid.config.Config;
import org.openqa.selenium.grid.config.Role;
import org.openqa.selenium.grid.data.NodeStatusEvent;
import org.openqa.selenium.grid.docker.DockerOptions;
import org.openqa.selenium.grid.log.LoggingOptions;
import org.openqa.selenium.grid.node.Node;
import org.openqa.selenium.grid.node.ProxyNodeCdp;
import org.openqa.selenium.grid.node.config.NodeOptions;
import org.openqa.selenium.grid.node.local.LocalNode;
import org.openqa.selenium.grid.server.BaseServerOptions;
import org.openqa.selenium.grid.server.EventBusOptions;
import org.openqa.selenium.grid.server.NetworkOptions;
Expand Down Expand Up @@ -115,17 +114,7 @@ protected void execute(Config config) {

NodeOptions nodeOptions = new NodeOptions(config);

LocalNode.Builder builder = LocalNode.builder(
tracer,
bus,
serverOptions.getExternalUri(),
nodeOptions.getPublicGridUri().orElseGet(serverOptions::getExternalUri),
serverOptions.getRegistrationSecret());

nodeOptions.configure(tracer, clientFactory, builder);
new DockerOptions(config).configure(tracer, clientFactory, builder);

LocalNode node = builder.build();
Node node = nodeOptions.getNode();

HttpHandler readinessCheck = req -> {
if (node.getStatus().hasCapacity()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@ java_library(
"//java/server/src/org/openqa/selenium/concurrent",
"//java/server/src/org/openqa/selenium/events",
"//java/server/src/org/openqa/selenium/grid/component",
"//java/server/src/org/openqa/selenium/grid/config",
"//java/server/src/org/openqa/selenium/grid/data",
"//java/server/src/org/openqa/selenium/grid/docker",
"//java/server/src/org/openqa/selenium/grid/log",
"//java/server/src/org/openqa/selenium/grid/node",
"//java/server/src/org/openqa/selenium/grid/node/config",
"//java/server/src/org/openqa/selenium/grid/server",
"//java/server/src/org/openqa/selenium/grid/web",
artifact("com.google.guava:guava"),
],
Expand Down
Loading

0 comments on commit 62d3333

Please sign in to comment.