Skip to content

Commit

Permalink
Parallelize stdout and clean up PersistentWorker (#501)
Browse files Browse the repository at this point in the history
* Parallelize stdout and refactor PersistentWorker

* Serialize writing to stdout

* Dont use flow
  • Loading branch information
jeffzoch authored Mar 12, 2021
1 parent af9596b commit 517d30a
Showing 1 changed file with 57 additions and 65 deletions.
122 changes: 57 additions & 65 deletions src/main/kotlin/io/bazel/worker/PersistentWorker.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,17 @@ import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import java.io.PrintStream
import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.Executors
import kotlin.coroutines.CoroutineContext
Expand All @@ -50,81 +60,63 @@ class PersistentWorker(

constructor() : this(Dispatchers.IO, IO.Companion::capture)

/**
* ThreadAwareDispatchers provides an ability to separate thread blocking operations from coroutines..
*
* Coroutines interleave actions over a pool of threads. When an action blocks it stands a chance
* of producing a deadlock. We sidestep this by providing a separate dispatcher to contain
* blocking operations, like reading from a stream. Inelegant, and a bit of a sledgehammer, but
* safe for the moment.
*/
private class BlockableDispatcher(
private val unblockedContext: CoroutineContext,
private val blockingContext: ExecutorCoroutineDispatcher,
scope: CoroutineScope
) : CoroutineScope by scope {
companion object {
fun <T> runIn(
owningContext: CoroutineContext,
exec: suspend BlockableDispatcher.() -> T
) =
Executors.newCachedThreadPool().asCoroutineDispatcher().use { dispatcher ->
runBlocking(owningContext) { BlockableDispatcher(owningContext, dispatcher, this).exec() }
}
}

fun <T> blockable(action: () -> T): T {
return runBlocking(blockingContext) {
return@runBlocking action()
}
}
}

@ExperimentalCoroutinesApi
override fun start(execute: Work) = WorkerContext.run {
//Use channel to serialize writing output
val writeChannel = Channel<WorkerProtocol.WorkResponse>(UNLIMITED)
captureIO().use { io ->
BlockableDispatcher.runIn(coroutineContext) {
blockable {
generateSequence { WorkRequest.parseDelimitedFrom(io.input) }
}.asFlow()
.map { request ->
info { "received req: ${request.requestId}" }
async {
doTask("request ${request.requestId}") { ctx ->
request.argumentsList.run {
execute(ctx, toList())
runBlocking {
//Parent coroutine to track all of children and close channel on completion
launch(Dispatchers.Default) {
generateSequence { WorkRequest.parseDelimitedFrom(io.input) }
.forEach { request ->
launch {
compileWork(request, io, writeChannel, execute)
}
}.let { result ->
info { "task result ${result.status}" }
WorkerProtocol.WorkResponse.newBuilder().apply {
output =
listOf(
result.log.out.toString(),
io.captured.toByteArray().toString(UTF_8)
).filter { it.isNotBlank() }.joinToString("\n")
exitCode = result.status.exit
requestId = request.requestId
}.build()
}
}
}
.buffer()
.map { deferred ->
deferred.await()
}
.collect { response ->
blockable {
info {
response.toString()
}
response.writeDelimitedTo(io.output)
io.output.flush()
}
}
}.invokeOnCompletion { writeChannel.close() }

writeChannel.consumeAsFlow()
.collect { response -> writeOutput(response, io.output) }
}

io.output.close()
info { "stopped worker" }
}
return@run 0
}

private suspend fun WorkerContext.compileWork(
request: WorkRequest,
io: IO,
chan: Channel<WorkerProtocol.WorkResponse>,
execute: Work
) = withContext(Dispatchers.Default) {
val result = doTask("request ${request.requestId}") { ctx ->
request.argumentsList.run {
execute(ctx, toList())
}
}
info { "task result ${result.status}" }
val response = WorkerProtocol.WorkResponse.newBuilder().apply {
output = listOf(
result.log.out.toString(),
io.captured.toByteArray().toString(UTF_8)
).filter { it.isNotBlank() }.joinToString("\n")
exitCode = result.status.exit
requestId = request.requestId
}.build()
info {
response.toString()
}
chan.send(response)
}

private suspend fun writeOutput(response: WorkerProtocol.WorkResponse, output: PrintStream) =
withContext(Dispatchers.IO) {
response.writeDelimitedTo(output)
output.flush()
}

}

0 comments on commit 517d30a

Please sign in to comment.