Skip to content

Commit

Permalink
Showing 9 changed files with 158 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -53,6 +53,8 @@ final class WorkerKey {
private final boolean isSpeculative;
/** A WorkerProxy will be instantiated if true, instantiate a regular Worker if false. */
private final boolean proxied;
/** If true, the workers for this key are able to cancel work requests. */
private final boolean cancellable;
/**
* Cached value for the hash of this key, because the value is expensive to calculate
* (ImmutableMap and ImmutableList do not cache their hashcodes).
@@ -70,6 +72,7 @@ final class WorkerKey {
SortedMap<PathFragment, HashCode> workerFilesWithHashes,
boolean isSpeculative,
boolean proxied,
boolean cancellable,
WorkerProtocolFormat protocolFormat) {
this.args = Preconditions.checkNotNull(args);
this.env = Preconditions.checkNotNull(env);
@@ -79,8 +82,8 @@ final class WorkerKey {
this.workerFilesWithHashes = Preconditions.checkNotNull(workerFilesWithHashes);
this.isSpeculative = isSpeculative;
this.proxied = proxied;
this.cancellable = cancellable;
this.protocolFormat = protocolFormat;

hash = calculateHashCode();
}

@@ -128,6 +131,10 @@ public boolean isMultiplex() {
return getProxied() && !isSpeculative;
}

/** Returns true if the workers for this key are able to cancel work requests. */
public boolean isCancellable() {
  return this.cancellable;
}

/** Returns the format of the worker protocol. */
public WorkerProtocolFormat getProtocolFormat() {
return protocolFormat;
Original file line number Diff line number Diff line change
@@ -77,8 +77,6 @@ final class WorkerSpawnRunner implements SpawnRunner {
public static final String REASON_NO_FLAGFILE =
"because the command-line arguments do not contain at least one @flagfile or --flagfile=";
public static final String REASON_NO_TOOLS = "because the action has no tools";
public static final String REASON_NO_EXECUTION_INFO =
"because the action's execution info does not contain 'supports-workers=1'";

/** Pattern for @flagfile.txt and --flagfile=flagfile.txt */
private static final Pattern FLAG_FILE_PATTERN = Pattern.compile("(?:@|--?flagfile=)(.+)");
@@ -205,6 +203,7 @@ public SpawnResult exec(Spawn spawn, SpawnExecutionContext context)
workerFiles,
context.speculating(),
multiplex && Spawns.supportsMultiplexWorkers(spawn),
Spawns.supportsWorkerCancellation(spawn),
protocolFormat);

SpawnMetrics.Builder spawnMetrics =
@@ -458,7 +457,11 @@ WorkResponse execInWorker(
try {
response = worker.getResponse(request.getRequestId());
} catch (InterruptedException e) {
finishWorkAsync(key, worker, request);
finishWorkAsync(
key,
worker,
request,
workerOptions.workerCancellation && Spawns.supportsWorkerCancellation(spawn));
worker = null;
throw e;
} catch (IOException e) {
@@ -480,6 +483,12 @@ WorkResponse execInWorker(
throw createEmptyResponseException(worker.getLogFile());
}

if (response.getWasCancelled()) {
throw createUserExecException(
"Received cancel response for " + response.getRequestId() + " without having cancelled",
Code.FINISH_FAILURE);
}

try {
Stopwatch processOutputsStopwatch = Stopwatch.createStarted();
context.lockOutputFiles();
@@ -525,12 +534,21 @@ WorkResponse execInWorker(
* interrupted. This takes ownership of the worker for purposes of returning it to the worker
* pool.
*/
private void finishWorkAsync(WorkerKey key, Worker worker, WorkRequest request) {
private void finishWorkAsync(
WorkerKey key, Worker worker, WorkRequest request, boolean canCancel) {
Thread reaper =
new Thread(
() -> {
Worker w = worker;
try {
if (canCancel) {
WorkRequest cancelRequest =
WorkRequest.newBuilder()
.setRequestId(request.getRequestId())
.setCancel(true)
.build();
w.putRequest(cancelRequest);
}
w.getResponse(request.getRequestId());
} catch (IOException | InterruptedException e1) {
// If this happens, we either can't trust the output of the worker, or we got
@@ -549,7 +567,8 @@ private void finishWorkAsync(WorkerKey key, Worker worker, WorkRequest request)
workers.returnObject(key, w);
}
}
});
},
"AsyncFinish-Worker-" + worker.workerId);
reaper.start();
}

Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat;
import com.google.devtools.build.lib.worker.ExampleWorkerOptions.ExampleWorkOptions;
import com.google.devtools.build.lib.worker.WorkRequestHandler.WorkerMessageProcessor;
import com.google.devtools.build.lib.worker.WorkerProtocol.Input;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
import com.google.devtools.common.options.OptionsParser;
@@ -42,12 +43,9 @@
import java.util.Map;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Semaphore;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import sun.misc.Signal;
import sun.misc.SignalHandler;

/** An example implementation of a worker process that is used for integration tests. */
public final class ExampleWorker {
@@ -70,6 +68,7 @@ public final class ExampleWorker {

// The options passed to this worker on a per-worker-lifetime basis.
static ExampleWorkerOptions workerOptions;
private static WorkerMessageProcessor messageProcessor;

private static class InterruptableWorkRequestHandler extends WorkRequestHandler {

@@ -118,7 +117,7 @@ public static void main(String[] args) throws Exception {
parser.parse(args);
workerOptions = parser.getOptions(ExampleWorkerOptions.class);
WorkerProtocolFormat protocolFormat = workerOptions.workerProtocol;
WorkRequestHandler.WorkerMessageProcessor messageProcessor = null;
messageProcessor = null;
switch (protocolFormat) {
case JSON:
messageProcessor =
@@ -147,21 +146,23 @@ private static int doWork(List<String> args, PrintWriter err) {
PrintStream originalStdOut = System.out;
PrintStream originalStdErr = System.err;

if (workerOptions.waitForSignal) {
Semaphore signalSem = new Semaphore(0);
Signal.handle(
new Signal("HUP"),
new SignalHandler() {
@Override
public void handle(Signal sig) {
signalSem.release();
}
});
if (workerOptions.waitForCancel) {
try {
signalSem.acquire();
} catch (InterruptedException e) {
System.out.println("Interrupted while waiting for signal");
e.printStackTrace();
WorkRequest workRequest = messageProcessor.readWorkRequest();
if (workRequest.getRequestId() != currentRequest.getRequestId()) {
System.err.format(
"Got cancel request for %d while expecting cancel request for %d%n",
workRequest.getRequestId(), currentRequest.getRequestId());
return 1;
}
if (!workRequest.getCancel()) {
System.err.format(
"Got non-cancel request for %d while expecting cancel request%n",
workRequest.getRequestId());
return 1;
}
} catch (IOException e) {
throw new RuntimeException("Exception while waiting for cancel request", e);
}
}
try (PrintStream ps = new PrintStream(baos)) {
Original file line number Diff line number Diff line change
@@ -136,12 +136,12 @@ public static class ExampleWorkOptions extends OptionsBase {
public boolean hardPoison;

@Option(
name = "wait_for_signal",
name = "wait_for_cancel",
documentationCategory = OptionDocumentationCategory.UNCATEGORIZED,
effectTags = {OptionEffectTag.NO_OP},
defaultValue = "false",
help = "Don't send a response until receiving a SIGXXXX.")
public boolean waitForSignal;
help = "Don't send a response until receiving a cancel request.")
public boolean waitForCancel;

/** Enum converter for --worker_protocol. */
public static class WorkerProtocolEnumConverter
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ static WorkerKey createWorkerKey(
/* workerFilesWithHashes= */ ImmutableSortedMap.of(),
/* mustBeSandboxed= */ false,
/* proxied= */ proxied,
/* cancellable= */ false,
WorkerProtocolFormat.PROTO);
}

@@ -58,6 +59,7 @@ static WorkerKey createWorkerKey(WorkerProtocolFormat protocolFormat, FileSystem
/* workerFilesWithHashes= */ ImmutableSortedMap.of(),
/* mustBeSandboxed= */ true,
/* proxied= */ true,
/* cancellable= */ false,
protocolFormat);
}

Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ protected WorkerKey createWorkerKey(boolean mustBeSandboxed, boolean proxied, St
/* workerFilesWithHashes= */ ImmutableSortedMap.of(),
/* mustBeSandboxed= */ mustBeSandboxed,
/* proxied= */ proxied,
/* cancellable= */ false,
WorkerProtocolFormat.PROTO);
}

Original file line number Diff line number Diff line change
@@ -43,6 +43,7 @@ private WorkerKey makeWorkerKey(boolean multiplex, boolean dynamic) {
/* workerFilesWithHashes= */ ImmutableSortedMap.of(),
/* isSpeculative= */ dynamic,
/* proxied= */ multiplex,
/* cancellable=*/ false,
WorkerProtocolFormat.PROTO);
}

@@ -90,6 +91,7 @@ public void testWorkerKeyEquality() {
workerKey.getWorkerFilesWithHashes(),
workerKey.isSpeculative(),
workerKey.getProxied(),
workerKey.isCancellable(),
workerKey.getProtocolFormat());
assertThat(workerKey).isEqualTo(workerKeyWithSameFields);
}
@@ -107,6 +109,7 @@ public void testWorkerKeyInequality_protocol() {
workerKey.getWorkerFilesWithHashes(),
workerKey.isSpeculative(),
workerKey.getProxied(),
workerKey.isCancellable(),
WorkerProtocolFormat.JSON);
assertThat(workerKey).isNotEqualTo(workerKeyWithDifferentProtocol);
}
Original file line number Diff line number Diff line change
@@ -59,6 +59,7 @@ public void instanceCreationRemovalTest() throws Exception {
ImmutableSortedMap.of(),
false,
false,
/* cancellable= */ false,
WorkerProtocolFormat.PROTO);
WorkerMultiplexer wm1 = WorkerMultiplexerManager.getInstance(workerKey1, logFile);

@@ -77,6 +78,7 @@ public void instanceCreationRemovalTest() throws Exception {
ImmutableSortedMap.of(),
false,
false,
/* cancellable= */ false,
WorkerProtocolFormat.PROTO);
WorkerMultiplexer wm2 = WorkerMultiplexerManager.getInstance(workerKey2, logFile);

Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import static com.google.common.truth.Truth.assertThat;
import static com.google.devtools.build.lib.worker.TestUtils.createWorkerKey;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -25,6 +26,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.actions.ExecException;
import com.google.devtools.build.lib.actions.ExecutionRequirements;
import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat;
import com.google.devtools.build.lib.actions.MetadataProvider;
import com.google.devtools.build.lib.actions.ResourceManager;
@@ -45,14 +47,17 @@
import com.google.devtools.build.lib.vfs.FileSystemUtils;
import com.google.devtools.build.lib.vfs.Path;
import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
import java.io.IOException;
import java.util.concurrent.Semaphore;
import org.apache.commons.pool2.PooledObject;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@@ -112,7 +117,6 @@ public void testExecInWorker_happyPath() throws ExecException, InterruptedExcept
new WorkerOptions());
WorkerKey key = createWorkerKey(fs, "mnem", false);
Path logFile = fs.getPath("/worker.log");
when(worker.getLogFile()).thenReturn(logFile);
when(worker.getResponse(0))
.thenReturn(WorkResponse.newBuilder().setExitCode(0).setOutput("out").build());
WorkResponse response =
@@ -134,12 +138,102 @@ public void testExecInWorker_happyPath() throws ExecException, InterruptedExcept
verify(context, times(1)).report(ProgressStatus.EXECUTING, "worker");
}

/**
 * Checks the interrupt path of {@code execInWorker} when a runner is built with a fresh
 * {@code WorkerOptions} instance: the {@link InterruptedException} thrown by the first
 * {@code getResponse} call must propagate to the caller, and only the original work request may
 * be sent to the worker (no second, cancel-flagged request).
 */
@Test
public void testExecInWorker_finishesAsyncOnInterrupt() throws InterruptedException, IOException {
WorkerSpawnRunner runner =
new WorkerSpawnRunner(
new SandboxHelpers(false),
fs.getPath("/execRoot"),
createWorkerPool(),
/* multiplex= */ false,
reporter,
localEnvProvider,
/* binTools= */ null,
resourceManager,
/* runfilesTreeUpdater= */ null,
// NOTE(review): presumably workerCancellation is off by default in WorkerOptions -- confirm.
new WorkerOptions());
WorkerKey key = createWorkerKey(fs, "mnem", false);
Path logFile = fs.getPath("/worker.log");
// First getResponse call simulates the interrupt; any later call (e.g. a follow-up read of the
// abandoned request) gets a plain response so it can complete.
when(worker.getResponse(anyInt()))
.thenThrow(new InterruptedException())
.thenReturn(WorkResponse.newBuilder().setRequestId(2).build());
assertThrows(
InterruptedException.class,
() ->
runner.execInWorker(
spawn,
key,
context,
new SandboxInputs(ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of()),
SandboxOutputs.create(ImmutableSet.of(), ImmutableSet.of()),
ImmutableList.of(),
inputFileCache,
spawnMetrics));
// NOTE(review): this test does not synchronize with any background work started by the
// interrupt handling; the assertions below only cover calls made before the exception escaped.
assertThat(logFile.exists()).isFalse();
verify(context, times(1)).report(ProgressStatus.EXECUTING, "worker");
// Exactly one request was sent, and it is the original request without the cancel flag.
verify(worker, times(1)).putRequest(WorkRequest.newBuilder().setRequestId(0).build());
}

/**
 * Checks that, when cancellation is enabled both in {@code WorkerOptions} and in the spawn's
 * execution info, an interrupt during {@code execInWorker} still propagates as
 * {@link InterruptedException} and is followed by a second request to the worker that carries
 * the same request id with the cancel flag set.
 */
@Test
public void testExecInWorker_sendsCancelMessageOnInterrupt()
throws ExecException, InterruptedException, IOException {
WorkerOptions workerOptions = new WorkerOptions();
// Cancellation is switched on in the options and advertised by the spawn's execution info.
workerOptions.workerCancellation = true;
when(spawn.getExecutionInfo())
.thenReturn(ImmutableMap.of(ExecutionRequirements.SUPPORTS_WORKER_CANCELLATION, "1"));
WorkerSpawnRunner runner =
new WorkerSpawnRunner(
new SandboxHelpers(false),
fs.getPath("/execRoot"),
createWorkerPool(),
/* multiplex= */ false,
reporter,
localEnvProvider,
/* binTools= */ null,
resourceManager,
/* runfilesTreeUpdater= */ null,
workerOptions);
WorkerKey key = createWorkerKey(fs, "mnem", false);
Path logFile = fs.getPath("/worker.log");
// Released when getResponse is called for the second time, letting the test thread wait until
// the follow-up read has actually been requested before verifying interactions.
Semaphore secondResponseRequested = new Semaphore(0);
// First call simulates the interrupt; the second call answers with a "was cancelled" response
// echoing whatever request id was asked for.
when(worker.getResponse(anyInt()))
.thenThrow(new InterruptedException())
.thenAnswer(
invocation -> {
secondResponseRequested.release();
return WorkResponse.newBuilder()
.setRequestId(invocation.getArgument(0))
.setWasCancelled(true)
.build();
});
assertThrows(
InterruptedException.class,
() ->
runner.execInWorker(
spawn,
key,
context,
new SandboxInputs(ImmutableMap.of(), ImmutableSet.of(), ImmutableMap.of()),
SandboxOutputs.create(ImmutableSet.of(), ImmutableSet.of()),
ImmutableList.of(),
inputFileCache,
spawnMetrics));
// Block until the second getResponse call has happened.
secondResponseRequested.acquire();
assertThat(logFile.exists()).isFalse();
verify(context, times(1)).report(ProgressStatus.EXECUTING, "worker");
ArgumentCaptor<WorkRequest> argumentCaptor = ArgumentCaptor.forClass(WorkRequest.class);
verify(worker, times(2)).putRequest(argumentCaptor.capture());
// Request 0: the original work request.
assertThat(argumentCaptor.getAllValues().get(0))
.isEqualTo(WorkRequest.newBuilder().setRequestId(0).build());
// Request 1: the cancel request for the same request id.
assertThat(argumentCaptor.getAllValues().get(1))
.isEqualTo(WorkRequest.newBuilder().setRequestId(0).setCancel(true).build());
}

@Test
public void testExecInWorker_noMultiplexWithDynamic()
throws ExecException, InterruptedException, IOException {
WorkerOptions workerOptions = new WorkerOptions();
workerOptions.workerMultiplex = true;
when(context.speculating()).thenReturn(true);
WorkerSpawnRunner runner =
new WorkerSpawnRunner(
new SandboxHelpers(false),
@@ -155,7 +249,6 @@ public void testExecInWorker_noMultiplexWithDynamic()
// This worker key just so happens to be multiplex and require sandboxing.
WorkerKey key = createWorkerKey(WorkerProtocolFormat.JSON, fs);
Path logFile = fs.getPath("/worker.log");
when(worker.getLogFile()).thenReturn(logFile);
when(worker.getResponse(0))
.thenReturn(
WorkResponse.newBuilder().setExitCode(0).setRequestId(0).setOutput("out").build());

0 comments on commit e9e6978

Please sign in to comment.