Skip to content

Commit

Permalink
[cdp] Intercept requests and responses in NetworkInterceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
shs96c committed Sep 8, 2021
1 parent 4ac8da5 commit 8a2e777
Show file tree
Hide file tree
Showing 8 changed files with 403 additions and 150 deletions.
149 changes: 94 additions & 55 deletions java/src/org/openqa/selenium/devtools/idealized/Network.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,47 @@

import org.openqa.selenium.Credentials;
import org.openqa.selenium.UsernameAndPassword;
import org.openqa.selenium.WebDriverException;
import org.openqa.selenium.devtools.Command;
import org.openqa.selenium.devtools.DevTools;
import org.openqa.selenium.devtools.DevToolsException;
import org.openqa.selenium.devtools.Event;
import org.openqa.selenium.internal.Either;
import org.openqa.selenium.internal.Require;
import org.openqa.selenium.remote.http.Contents;
import org.openqa.selenium.remote.http.Filter;
import org.openqa.selenium.remote.http.HttpMethod;
import org.openqa.selenium.remote.http.HttpRequest;
import org.openqa.selenium.remote.http.HttpResponse;
import org.openqa.selenium.remote.http.Routable;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.logging.Logger;

import static java.net.HttpURLConnection.HTTP_OK;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.logging.Level.WARNING;

public abstract class Network<AUTHREQUIRED, REQUESTPAUSED> {

private static final Logger LOG = Logger.getLogger(Network.class.getName());

private final Map<Predicate<URI>, Supplier<Credentials>> authHandlers = new LinkedHashMap<>();
private final Map<Predicate<HttpRequest>, Function<HttpRequest, HttpResponse>> uriHandlers = new LinkedHashMap<>();
private final Filter defaultFilter = next -> next::execute;
private Filter filter = defaultFilter;
protected final DevTools devTools;
private boolean interceptingTraffic = false;

public Network(DevTools devtools) {
this.devTools = Require.nonNull("DevTools", devtools);
Expand All @@ -55,8 +70,7 @@ public void disable() {
devTools.send(enableNetworkCaching());

authHandlers.clear();
uriHandlers.clear();
interceptingTraffic = false;
filter = defaultFilter;
}

public static class UserAgent {
Expand Down Expand Up @@ -113,35 +127,19 @@ public void addAuthHandler(Predicate<URI> whenThisMatches, Supplier<Credentials>
prepareToInterceptTraffic();
}

public OpaqueKey addRequestHandler(Routable routable) {
Require.nonNull("Routable", routable);

return addRequestHandler(routable::matches, routable::execute);
@SuppressWarnings("SuspiciousMethodCalls")
public void resetNetworkFilter() {
filter = defaultFilter;
}

public OpaqueKey addRequestHandler(Predicate<HttpRequest> whenThisMatches, Function<HttpRequest, HttpResponse> returnThis) {
Require.nonNull("Request predicate", whenThisMatches);
Require.nonNull("Handler", returnThis);

uriHandlers.put(whenThisMatches, returnThis);
public void interceptTrafficWith(Filter filter) {
Require.nonNull("HTTP filter", filter);

this.filter = filter;
prepareToInterceptTraffic();

return new OpaqueKey(whenThisMatches);
}

@SuppressWarnings("SuspiciousMethodCalls")
public void removeRequestHandler(OpaqueKey key) {
Require.nonNull("Key", key);

uriHandlers.remove(key.getValue());
}

private void prepareToInterceptTraffic() {
if (interceptingTraffic) {
return;
}

public void prepareToInterceptTraffic() {
devTools.send(disableNetworkCaching());

devTools.addListener(
Expand Down Expand Up @@ -169,35 +167,56 @@ private void prepareToInterceptTraffic() {
devTools.send(cancelAuth(authRequired));
});

Map<String, CompletableFuture<HttpResponse>> responses = new ConcurrentHashMap<>();

devTools.addListener(
requestPausedEvent(),
pausedRequest -> {
Optional<HttpRequest> req = createHttpRequest(pausedRequest);
String id = getRequestId(pausedRequest);
Either<HttpRequest, HttpResponse> message = createSeMessages(pausedRequest);

if (!req.isPresent()) {
devTools.send(continueWithoutModification(pausedRequest));
return;
}
if (message.isRight()) {
HttpResponse res = message.right();
CompletableFuture<HttpResponse> future = responses.remove(id);

Optional<HttpResponse> maybeRes = getHttpResponse(req.get());
if (!maybeRes.isPresent()) {
devTools.send(continueWithoutModification(pausedRequest));
if (future == null) {
devTools.send(continueWithoutModification(pausedRequest));
return;
}

future.complete(res);
return;
}

HttpResponse response = maybeRes.get();
HttpResponse forBrowser = filter.andFinally(req -> {
// Convert the selenium request to a CDP one and fulfill.

CompletableFuture<HttpResponse> res = new CompletableFuture<>();
responses.put(id, res);

devTools.send(continueRequest(pausedRequest, req));

// Wait for the CDP response and send that back.
try {
return res.get();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new WebDriverException(e);
} catch (ExecutionException e) {
LOG.log(WARNING, e, () -> "Unable to process request");
return new HttpResponse();
}
}).execute(message.left());

if ("Continue".equals(response.getHeader("Selenium-Interceptor"))) {
if ("Continue".equals(forBrowser.getHeader("Selenium-Interceptor"))) {
devTools.send(continueWithoutModification(pausedRequest));
return;
}

devTools.send(createResponse(pausedRequest, response));
devTools.send(fulfillRequest(pausedRequest, forBrowser));
});

devTools.send(enableFetchForAllPatterns());

interceptingTraffic = true;
}

protected Optional<Credentials> getAuthCredentials(URI uri) {
Expand All @@ -210,16 +229,6 @@ protected Optional<Credentials> getAuthCredentials(URI uri) {
.findFirst();
}

protected Optional<HttpResponse> getHttpResponse(HttpRequest forRequest) {
Require.nonNull("Request", forRequest);

return uriHandlers.entrySet().stream()
.filter(entry -> entry.getKey().test(forRequest))
.map(Map.Entry::getValue)
.map(func -> func.apply(forRequest))
.findFirst();
}

protected HttpMethod convertFromCdpHttpMethod(String method) {
Require.nonNull("HTTP Method", method);
try {
Expand All @@ -230,6 +239,32 @@ protected HttpMethod convertFromCdpHttpMethod(String method) {
}
}

protected HttpResponse createHttpResponse(
Optional<Integer> statusCode,
String body,
Boolean bodyIsBase64Encoded,
List<Map.Entry<String, String>> headers) {
Supplier<InputStream> content;
if (bodyIsBase64Encoded != null && bodyIsBase64Encoded) {
byte[] decoded = Base64.getDecoder().decode(body);
content = () -> new ByteArrayInputStream(decoded);
} else {
content = Contents.string(body, UTF_8);
}

HttpResponse res = new HttpResponse()
.setStatus(statusCode.orElse(HTTP_OK))
.setContent(content);

headers.forEach(entry -> {
if (entry.getValue() != null) {
res.addHeader(entry.getKey(), entry.getValue());
}
});

return res;
}

protected HttpRequest createHttpRequest(
String cdpMethod,
String url,
Expand Down Expand Up @@ -262,9 +297,13 @@ protected HttpRequest createHttpRequest(

protected abstract Event<REQUESTPAUSED> requestPausedEvent();

protected abstract Optional<HttpRequest> createHttpRequest(REQUESTPAUSED pausedRequest);
protected abstract String getRequestId(REQUESTPAUSED pausedReq);

protected abstract Either<HttpRequest, HttpResponse> createSeMessages(REQUESTPAUSED pausedReq);

protected abstract Command<Void> continueWithoutModification(REQUESTPAUSED pausedReq);

protected abstract Command<Void> continueWithoutModification(REQUESTPAUSED pausedRequest);
protected abstract Command<Void> continueRequest(REQUESTPAUSED pausedReq, HttpRequest req);

protected abstract Command<Void> createResponse(REQUESTPAUSED pausedRequest, HttpResponse response);
protected abstract Command<Void> fulfillRequest(REQUESTPAUSED pausedReq, HttpResponse res);
}
95 changes: 73 additions & 22 deletions java/src/org/openqa/selenium/devtools/v85/V85Network.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.openqa.selenium.devtools.v85;

import com.google.common.collect.ImmutableList;
import com.google.common.io.ByteStreams;
import org.openqa.selenium.UsernameAndPassword;
import org.openqa.selenium.devtools.Command;
import org.openqa.selenium.devtools.DevTools;
Expand All @@ -29,14 +30,20 @@
import org.openqa.selenium.devtools.v85.fetch.model.HeaderEntry;
import org.openqa.selenium.devtools.v85.fetch.model.RequestPattern;
import org.openqa.selenium.devtools.v85.fetch.model.RequestPaused;
import org.openqa.selenium.devtools.v85.fetch.model.RequestStage;
import org.openqa.selenium.devtools.v85.network.model.Request;
import org.openqa.selenium.remote.http.Contents;
import org.openqa.selenium.internal.Either;
import org.openqa.selenium.remote.http.HttpRequest;
import org.openqa.selenium.remote.http.HttpResponse;

import java.util.ArrayList;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.AbstractMap;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class V85Network extends Network<AuthRequired, RequestPaused> {
Expand Down Expand Up @@ -64,7 +71,9 @@ protected Command<Void> disableNetworkCaching() {
@Override
protected Command<Void> enableFetchForAllPatterns() {
return Fetch.enable(
Optional.of(ImmutableList.of(new RequestPattern(Optional.of("*"), Optional.empty(), Optional.empty()))),
Optional.of(ImmutableList.of(
new RequestPattern(Optional.of("*"), Optional.empty(), Optional.of(RequestStage.REQUEST)),
new RequestPattern(Optional.of("*"), Optional.empty(), Optional.of(RequestStage.RESPONSE)))),
Optional.of(true));
}

Expand Down Expand Up @@ -101,23 +110,42 @@ protected Command<Void> cancelAuth(AuthRequired authRequired) {
}

@Override
protected Event<RequestPaused> requestPausedEvent() {
public Event<RequestPaused> requestPausedEvent() {
return Fetch.requestPaused();
}

@Override
protected Optional<HttpRequest> createHttpRequest(RequestPaused pausedRequest) {
if (pausedRequest.getResponseErrorReason().isPresent() || pausedRequest.getResponseStatusCode().isPresent()) {
return Optional.empty();
public Either<HttpRequest, HttpResponse> createSeMessages(RequestPaused pausedReq) {
if (pausedReq.getResponseStatusCode().isPresent() || pausedReq.getResponseErrorReason().isPresent()) {
Fetch.GetResponseBodyResponse base64Body = devTools.send(Fetch.getResponseBody(pausedReq.getRequestId()));

List<Map.Entry<String, String>> headers = new LinkedList<>();
pausedReq.getResponseHeaders().ifPresent(resHeaders ->
resHeaders.forEach(header -> headers.add(new AbstractMap.SimpleEntry<>(header.getName(), header.getValue()))));

HttpResponse res = createHttpResponse(
pausedReq.getResponseStatusCode(),
base64Body.getBody(),
base64Body.getBase64Encoded(),
headers);

return Either.right(res);
}

Request cdpRequest = pausedRequest.getRequest();
Request cdpReq = pausedReq.getRequest();

HttpRequest req = createHttpRequest(
cdpReq.getMethod(),
cdpReq.getUrl(),
cdpReq.getHeaders(),
cdpReq.getPostData());

return Optional.of(createHttpRequest(
cdpRequest.getMethod(),
cdpRequest.getUrl(),
cdpRequest.getHeaders(),
cdpRequest.getPostData()));
return Either.left(req);
}

@Override
protected String getRequestId(RequestPaused pausedReq) {
return pausedReq.getRequestId().toString();
}

@Override
Expand All @@ -131,20 +159,43 @@ protected Command<Void> continueWithoutModification(RequestPaused pausedRequest)
}

@Override
protected Command<Void> createResponse(RequestPaused pausedRequest, HttpResponse response) {
List<HeaderEntry> headers = new ArrayList<>();
response.getHeaderNames().forEach(
name -> response.getHeaders(name).forEach(value -> headers.add(new HeaderEntry(name, value))));
protected Command<Void> continueRequest(RequestPaused pausedReq, HttpRequest req) {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try (InputStream is = req.getContent().get()) {
ByteStreams.copy(is, bos);
} catch (IOException e) {
return continueWithoutModification(pausedReq);
}

byte[] bytes = Contents.bytes(response.getContent());
String body = bytes.length > 0 ? Base64.getEncoder().encodeToString(bytes) : null;
List<HeaderEntry> headers = new LinkedList<>();
req.getHeaderNames().forEach(name -> req.getHeaders(name).forEach(value -> headers.add(new HeaderEntry(name, value))));

return Fetch.continueRequest(
pausedReq.getRequestId(),
Optional.empty(),
Optional.of(req.getMethod().toString()),
Optional.of(Base64.getEncoder().encodeToString(bos.toByteArray())),
Optional.of(headers));
}

@Override
protected Command<Void> fulfillRequest(RequestPaused pausedReq, HttpResponse res) {
List<HeaderEntry> headers = new LinkedList<>();
res.getHeaderNames().forEach(name -> res.getHeaders(name).forEach(value -> headers.add(new HeaderEntry(name, value))));

ByteArrayOutputStream bos = new ByteArrayOutputStream();
try (InputStream is = res.getContent().get()) {
ByteStreams.copy(is, bos);
} catch (IOException e) {
bos.reset();
}

return Fetch.fulfillRequest(
pausedRequest.getRequestId(),
response.getStatus(),
pausedReq.getRequestId(),
res.getStatus(),
Optional.of(headers),
Optional.empty(),
Optional.ofNullable(body),
Optional.of(Base64.getEncoder().encodeToString(bos.toByteArray())),
Optional.empty());
}
}
Loading

0 comments on commit 8a2e777

Please sign in to comment.