Skip to content

Commit

Permalink
Restructure the worker to be deterministic and hopefully clearer abou…
Browse files Browse the repository at this point in the history
…t the contract being honored.

Adds BazelWorkerTest that checks the behaviours of both invocation workers and persistent.
  • Loading branch information
Corbin Smith authored and cgruber committed Nov 16, 2020
1 parent 0779709 commit cc039db
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 46 deletions.
36 changes: 19 additions & 17 deletions src/main/kotlin/io/bazel/kotlin/builder/tasks/BazelWorker.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.util.logging.Level.SEVERE
import java.util.logging.Level
import java.util.logging.Logger

/**
Expand Down Expand Up @@ -150,7 +150,7 @@ class PersistentWorker(
private val io: WorkerIO,
private val program: CommandLineProgram
) : Worker {
private val logger = Logger.getLogger(PersistentWorker::class.java.canonicalName)
val logger = Logger.getLogger(PersistentWorker::class.java.canonicalName)

enum class Status {
OK, INTERRUPTED, ERROR
Expand All @@ -163,15 +163,17 @@ class PersistentWorker(
val (status, exit) = WorkingDirectoryContext.newContext()
.runCatching {
request.argumentsList
?.let { maybeExpand(it) }
?.let {
maybeExpand(it)
}
.run {
Status.OK to program.apply(dir, maybeExpand(request.argumentsList))
}
}
.recover { e: Throwable ->
io.execution.write((e.message ?: e.toString()).toByteArray(UTF_8))
if (!e.wasInterrupted()) {
logger.log(SEVERE,
logger.log(Level.SEVERE,
"ERROR: Worker threw uncaught exception",
e)
Status.ERROR to 1
Expand All @@ -181,10 +183,10 @@ class PersistentWorker(
}
.getOrThrow()

val response = WorkResponse.newBuilder().apply {
val response = with(WorkResponse.newBuilder()) {
output = String(io.execution.toByteArray(), UTF_8)
exitCode = exit
requestId = request.requestId
setRequestId(request.requestId)
}.build()

// return the response
Expand All @@ -207,20 +209,20 @@ class InvocationWorker(
private val io: WorkerIO,
private val program: CommandLineProgram
) : Worker {
private val logger: Logger = Logger.getLogger(InvocationWorker::class.java.canonicalName)
override fun run(args: List<String>): Int = WorkingDirectoryContext.newContext()
.runCatching { program.apply(dir, maybeExpand(args)) }
.recover { e ->
logger.log(SEVERE,
val logger: Logger = Logger.getLogger(InvocationWorker::class.java.canonicalName)
override fun run(args: List<String>): Int {
return WorkingDirectoryContext.newContext().runCatching {
program.apply(dir, maybeExpand(args))
}.recover { e ->
logger.log(Level.SEVERE,
"ERROR: Worker threw uncaught exception with args: ${maybeExpand(args)}",
e)
return@recover 1 // return non-0 exitcode
}
.also {
// print execution log
1
}.also {
println(String(io.execution.toByteArray(), UTF_8))
}
.getOrDefault(0)
}.getOrDefault(0)
}
}



1 change: 0 additions & 1 deletion src/test/kotlin/io/bazel/kotlin/builder/tasks/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ kt_jvm_test(
]
)


# TODO(bazelbuild/rules_kotlin/issues/275): Remove full jar reference when the kt_rules_test handles jvm_import data better.
_MAVEN_CENTRAL_PREFIX = "@kotlin_rules_maven//:v1/https/maven-central.storage.googleapis.com/repos/central/data"

Expand Down
59 changes: 31 additions & 28 deletions src/test/kotlin/io/bazel/kotlin/builder/tasks/BazelWorkerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package io.bazel.kotlin.builder.tasks

import com.google.common.truth.Truth.assertThat
import com.google.common.truth.Truth.assertWithMessage
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse
import kotlinx.coroutines.CoroutineName
Expand All @@ -28,10 +27,10 @@ import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.sendBlocking
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.debug.junit4.CoroutinesTimeout
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runBlockingTest
import kotlinx.coroutines.yield
import org.junit.Rule
import org.junit.Test
import java.io.ByteArrayOutputStream
Expand Down Expand Up @@ -69,6 +68,8 @@ class BazelWorkerTest {

@Test
fun persistentWorker() {
val workerInput = WorkerChannel("in")
val workerOutput = WorkerChannel("out")
runBlockingTest {
val program = object : CommandLineProgram {
override fun apply(workingDir: Path, args: List<String>): Int {
Expand All @@ -82,27 +83,23 @@ class BazelWorkerTest {
}
}

val workerInput = WorkerChannel("in")
val workerOutput = WorkerChannel("out")
val execution = ByteArrayOutputStream()

val worker = GlobalScope.async(CoroutineName("worker")) {
val done = GlobalScope.async(CoroutineName("worker")) {
WorkerIO(workerInput.input,
PrintStream(workerOutput.output),
execution) {}
.use { io ->
PersistentWorker(io, program).run(listOf())
}
execution) {}.use { io ->
PersistentWorker(io, program).run(listOf())
}
}

// asserts scope to ensure all cases are run.
// asserts scope to ensure all asserts are run.
// messy, can be cleaned up -- since kotlin channels love to block and can easily starve a
// dispatcher, it's necessary to read each channel in a different coroutine.
// The coroutineScope prevents the assertions from happening outside of the expected
// asynchronicity.
coroutineScope {
launch {
assertWithMessage("worker is active").that(worker.isActive).isTrue()
workerInput.send(request(1, "ok"))
}
launch {
Expand All @@ -113,7 +110,6 @@ class BazelWorkerTest {

coroutineScope {
launch {
assertWithMessage("worker is active").that(worker.isActive).isTrue()
workerInput.send(request(2, "fail"))
}
launch {
Expand All @@ -124,7 +120,6 @@ class BazelWorkerTest {

coroutineScope {
launch {
assertWithMessage("worker is active").that(worker.isActive).isTrue()
workerInput.send(request(3, "error"))
}
launch {
Expand All @@ -136,7 +131,6 @@ class BazelWorkerTest {
// an interrupt kills the worker.
coroutineScope {
launch {
assertWithMessage("worker is active").that(worker.isActive).isTrue()
workerInput.send(request(4, "interrupt"))
workerInput.close()
}
Expand All @@ -146,20 +140,22 @@ class BazelWorkerTest {
}
}

assertThat(worker.await()).isEqualTo(0)

assertThat(done.await()).isEqualTo(0)
}
}

private fun request(id: Int, vararg args: String) = WorkRequest.newBuilder().apply {
requestId = id
private fun request(id: Int, vararg args: String) = with(WorkRequest.newBuilder()) {
setRequestId(id)
addAllArguments(args.asList())
}.build()

private fun response(id: Int, code: Int, out: String = "") = WorkResponse.newBuilder().apply {
exitCode = code
output = out
requestId = id
}.build()
private fun response(id: Int, exitCode: Int, output: String = "") =
with(WorkResponse.newBuilder()) {
setExitCode(exitCode)
setOutput(output)
setRequestId(id)
}.build()

/** WorkerChannel encapsulates the communication between the test and the worker. */
class WorkerChannel(
Expand All @@ -185,8 +181,10 @@ class BazelWorkerTest {
class ChannelInputStream(val channel: Channel<Byte>, val name: String) : InputStream() {
override fun read(): Int {
return runBlocking(CoroutineName("$name.read()")) {
if (channel.isEmpty) {
yield()
// since pipes block until the next event, this simulates that without starving
// other routines.
while (channel.isEmpty) {
delay(5L)
}
// read blocking -- this better simulates the java InputStream behaviour.
return@runBlocking channel.receive().toInt()
Expand All @@ -195,11 +193,14 @@ class BazelWorkerTest {

override fun read(b: ByteArray, off: Int, len: Int): Int {
return runBlocking(CoroutineName("$name.read(ByteArray,Int,Int)")) {
if (channel.isEmpty) {
yield()
// since pipes block until the next event, this simulates that without starving
// other routines.
while (channel.isEmpty) {
delay(5L)
}
val end = Math.min(b.size, off + len - 1)
var read = 0
for (i in off..b.size.coerceAtMost(off + len - 1)) {
for (i in off..end) {
val rb = channel.receive()
b[i] = rb
read++
Expand All @@ -217,8 +218,10 @@ class BazelWorkerTest {

override fun write(ba: ByteArray, off: Int, len: Int) {
runBlocking(CoroutineName("$name.write(ByteArray, Int, Int)")) {
for (i in off..ba.size.coerceAtMost(off + len - 1)) {
var sent = 0
for (i in off..Math.min(ba.size, off + len - 1)) {
channel.sendBlocking(ba[i])
sent++
}
}
}
Expand Down

0 comments on commit cc039db

Please sign in to comment.