lwc: keep input to TimeGrouped batched #1490

Merged
merged 2 commits on Nov 13, 2022
@@ -296,7 +296,12 @@ object LwcMessages {
data(i + 1) = parser.nextTextValue()
i += 2
}
SortedTagMap.createUnsafe(data, data.length)
// The map should be sorted from the server side, so we can avoid resorting
// here. Force the hash to be computed and cached as soon as possible so it
// can be done on the compute pool for parsing.
val tags = SortedTagMap.createUnsafe(data, data.length)
tags.hashCode
tags
}
}
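Note on the change above: forcing tags.hashCode only helps because the hash is computed once and then memoized. A minimal sketch of that pattern, assuming a lazily cached hash field (SortedTagMap's actual layout may differ):

import scala.util.hashing.MurmurHash3

// Illustrative only: 0 doubles as the "not yet computed" sentinel, so a
// zero hash is recomputed harmlessly. Calling hashCode right after parsing
// pays this cost on the parsing pool, not on the first consumer thread.
class CachedHashTags(data: Array[String]) {
  private var cached: Int = 0

  override def hashCode(): Int = {
    if (cached == 0) cached = MurmurHash3.arrayHash(data)
    cached
  }
}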

@@ -65,7 +65,12 @@ import org.reactivestreams.Processor
import org.reactivestreams.Publisher
import org.slf4j.LoggerFactory

import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.Await
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._

@@ -94,6 +99,24 @@ private[stream] abstract class EvaluatorImpl(
// Counter for messages that cannot be parsed
private val badMessages = registry.counter("atlas.eval.badMessages")

// Number of threads to use for parsing payloads
private val parsingNumThreads = math.max(Runtime.getRuntime.availableProcessors() / 2, 2)

// Execution context to use for parsing payloads coming back from lwcapi service
private val parsingEC = {
val threadCount = new AtomicInteger()
val factory = new ThreadFactory {
override def newThread(r: Runnable): Thread = {
val name = s"AtlasEvalParsing-${threadCount.getAndIncrement()}"
val thread = new Thread(r, name)
thread.setDaemon(true)
thread
}
}
val executor = Executors.newFixedThreadPool(parsingNumThreads, factory)
ExecutionContext.fromExecutor(executor)
}
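The pool above uses daemon threads, so it cannot block JVM shutdown, and is sized to roughly half the available cores to leave headroom for the rest of the stream. A hypothetical usage sketch (the workload below is a stand-in, not code from the PR):

import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

// Run CPU-bound work on the dedicated pool so the stream's default
// dispatcher stays free for I/O and graph bookkeeping.
val work: Future[Int] = Future((1 to 10000).sum)(parsingEC)
Await.result(work, 1.second)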

private def newStreamContext(dsLogger: DataSourceLogger = (_, _) => ()): StreamContext = {
new StreamContext(
config,
@@ -233,10 +256,14 @@ private[stream] abstract class EvaluatorImpl(
val finalEvalInput = builder.add(Merge[AnyRef](2))

val intermediateEval = createInputFlow(context)
.via(context.monitorFlow("10_InputLines"))
.via(context.monitorFlow("10_InputBatches"))
.via(new LwcToAggrDatapoint(context))
.groupBy(Int.MaxValue, _.step, allowClosedSubstreamRecreation = true)
.flatMapConcat { vs =>
Source(vs.groupBy(_.step).map(_._2.toList))
}
.groupBy(Int.MaxValue, _.head.step, allowClosedSubstreamRecreation = true)
.via(new TimeGrouped(context))
.flatMapConcat(Source.apply)
.mergeSubstreams
.via(context.monitorFlow("11_GroupedDatapoints"))

@@ -400,7 +427,7 @@ private[stream] abstract class EvaluatorImpl(

private[stream] def createInputFlow(
context: StreamContext
): Flow[DataSources, AnyRef, NotUsed] = {
): Flow[DataSources, List[AnyRef], NotUsed] = {

val g = GraphDSL.create() { implicit builder =>
import GraphDSL.Implicits._
@@ -410,7 +437,7 @@

// Merge the data coming from remote and local before performing
// the time grouping and aggregation
val inputMerge = builder.add(Merge[AnyRef](2))
val inputMerge = builder.add(Merge[List[AnyRef]](2))

// Streams for remote (lwc-api cluster)
val remoteFlow =
@@ -420,7 +447,7 @@
val localFlow = Flow[DataSources]
.flatMapMerge(Int.MaxValue, s => Source(s.getSources.asScala.toList))
.flatMapMerge(Int.MaxValue, s => context.localSource(Uri(s.getUri)))
.flatMapConcat(parseMessage)
.map(parseMessage)

// Broadcast to remote/local flow, process and merge
dataSourcesBroadcast.out(0).map(_.remoteOnly()) ~> remoteFlow ~> inputMerge.in(0)
@@ -449,7 +476,7 @@ private[stream] abstract class EvaluatorImpl(
// Streams via WebSocket API `/api/v1/subscribe`, from each instance of lwc-api cluster
private def createClusterStreamFlow(
context: StreamContext
): Flow[SourcesAndGroups, AnyRef, NotUsed] = {
): Flow[SourcesAndGroups, List[AnyRef], NotUsed] = {
Flow[SourcesAndGroups]
.via(StreamOps.unique())
.flatMapConcat { sourcesAndGroups =>
@@ -469,7 +496,7 @@

private def createGroupByContext(
context: StreamContext
): ClusterOps.GroupByContext[Instance, Set[LwcExpression], AnyRef] = {
): ClusterOps.GroupByContext[Instance, Set[LwcExpression], List[AnyRef]] = {
ClusterOps.GroupByContext(
instance => createWebSocketFlow(instance),
registry,
@@ -487,7 +514,7 @@

private def createWebSocketFlow(
instance: EurekaSource.Instance
): Flow[Set[LwcExpression], AnyRef, NotUsed] = {
): Flow[Set[LwcExpression], List[AnyRef], NotUsed] = {
val base = instance.substitute("ws://{local-ipv4}:{port}")
val id = UUID.randomUUID().toString
val uri = s"$base/api/v$lwcapiVersion/subscribe/$id"
@@ -502,35 +529,38 @@
case _: TextMessage =>
throw new MatchError("text messages are not supported")
case BinaryMessage.Strict(str) =>
parseBatch(str)
Source.single(str)
case msg: BinaryMessage =>
msg.dataStream.fold(ByteString.empty)(_ ++ _).flatMapConcat(parseBatch)
msg.dataStream.fold(ByteString.empty)(_ ++ _)
}
.mapAsyncUnordered(parsingNumThreads) { msg =>
Future(parseBatch(msg))(parsingEC)
}
.mapMaterializedValue(_ => NotUsed)
}
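Parsing is now bounded by the pool size and deliberately unordered; completion order presumably does not matter here because TimeGrouped re-buckets datapoints by timestamp downstream. A self-contained sketch of the operator's behavior (illustrative, not the PR's code):

import akka.actor.ActorSystem
import akka.stream.scaladsl.{Sink, Source}
import scala.concurrent.Future

object MapAsyncSketch extends App {
  implicit val system: ActorSystem = ActorSystem("sketch")
  import system.dispatcher

  // At most 4 futures run concurrently; results are emitted as they
  // complete, not in input order.
  Source(1 to 10)
    .mapAsyncUnordered(4)(i => Future(i * i))
    .runWith(Sink.foreach(println))
    .onComplete(_ => system.terminate())
}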

private def parseBatch(message: ByteString): Source[AnyRef, NotUsed] = {
private def parseBatch(message: ByteString): List[AnyRef] = {
try {
ReplayLogging.log(message)
Source(LwcMessages.parseBatch(message))
LwcMessages.parseBatch(message)
} catch {
case e: Exception =>
logger.warn(s"failed to process message [$message]", e)
badMessages.increment()
Source.empty
List.empty
}
}

private def parseMessage(message: ByteString): Source[AnyRef, NotUsed] = {
private def parseMessage(message: ByteString): List[AnyRef] = {
try {
ReplayLogging.log(message)
Source.single(LwcMessages.parse(message))
List(LwcMessages.parse(message))
} catch {
case e: Exception =>
val messageString = toString(message)
logger.warn(s"failed to process message [$messageString]", e)
badMessages.increment()
Source.empty
List.empty
}
}
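Both helpers share the same fail-soft shape: a malformed payload is logged and counted, then replaced by an empty batch so one bad message cannot fail the stream. The pattern in the abstract (parseSafe is a hypothetical name; println stands in for the logger and counter):

import scala.util.{Failure, Success, Try}

def parseSafe[A](raw: String)(parse: String => A): List[A] =
  Try(parse(raw)) match {
    case Success(value) => List(value)
    case Failure(e) =>
      println(s"failed to process message [$raw]: $e")
      List.empty
  }

parseSafe("42")(_.toInt)   // List(42)
parseSafe("oops")(_.toInt) // List() after logging the failure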

@@ -35,12 +35,12 @@ import com.netflix.atlas.eval.model.LwcSubscription
* [[AggrDatapoint]]s that can be used for evaluation.
*/
private[stream] class LwcToAggrDatapoint(context: StreamContext)
extends GraphStage[FlowShape[AnyRef, AggrDatapoint]] {
extends GraphStage[FlowShape[List[AnyRef], List[AggrDatapoint]]] {

private val in = Inlet[AnyRef]("LwcToAggrDatapoint.in")
private val out = Outlet[AggrDatapoint]("LwcToAggrDatapoint.out")
private val in = Inlet[List[AnyRef]]("LwcToAggrDatapoint.in")
private val out = Outlet[List[AggrDatapoint]]("LwcToAggrDatapoint.out")

override val shape: FlowShape[AnyRef, AggrDatapoint] = FlowShape(in, out)
override val shape: FlowShape[List[AnyRef], List[AggrDatapoint]] = FlowShape(in, out)

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = {
new GraphStageLogic(shape) with InHandler with OutHandler {
@@ -51,13 +51,19 @@ private[stream] class LwcToAggrDatapoint(context: StreamContext)
private var nextSource: Int = 0

override def onPush(): Unit = {
grab(in) match {
val builder = List.newBuilder[AggrDatapoint]
grab(in).foreach {
case sb: LwcSubscription => updateState(sb)
case dp: LwcDatapoint => pushDatapoint(dp)
case dp: LwcDatapoint => builder ++= pushDatapoint(dp)
case dg: LwcDiagnosticMessage => pushDiagnosticMessage(dg)
case hb: LwcHeartbeat => pushHeartbeat(hb)
case _ => pull(in)
case hb: LwcHeartbeat => builder += pushHeartbeat(hb)
case _ =>
}
val datapoints = builder.result()
if (datapoints.isEmpty)
pull(in)
else
push(out, datapoints)
}
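Collecting into a builder and then deciding keeps the stage's demand loop intact: every grabbed batch results in exactly one push of a non-empty result or one pull for more input. A pure-function sketch of the per-element routing (types and names are illustrative, not the PR's):

sealed trait Msg
final case class Sub(id: String)           extends Msg
final case class Dp(id: String, v: Double) extends Msg
case object Heartbeat                      extends Msg

// Subscriptions update local state, datapoints resolve against it,
// heartbeats always pass through (NaN stands in for a heartbeat
// datapoint), anything unknown is dropped.
def route(batch: List[Msg], subs: Set[String]): (Set[String], List[Double]) =
  batch.foldLeft((subs, List.empty[Double])) {
    case ((s, out), Sub(id))            => (s + id, out)
    case ((s, out), Dp(id, v)) if s(id) => (s, out :+ v)
    case ((s, out), Heartbeat)          => (s, out :+ Double.NaN)
    case ((s, out), _)                  => (s, out)
  }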

private def updateState(sub: LwcSubscription): Unit = {
@@ -66,34 +72,26 @@ private[stream] class LwcToAggrDatapoint(context: StreamContext)
state.put(m.id, m)
}
}
pull(in)
}

private def pushDatapoint(dp: LwcDatapoint): Unit = {
state.get(dp.id) match {
case Some(sub) =>
// TODO, put in source, for now make it random to avoid dedup
nextSource += 1
val expr = sub.expr
val step = sub.step
push(
out,
AggrDatapoint(dp.timestamp, step, expr, nextSource.toString, dp.tags, dp.value)
)
case None =>
pull(in)
private def pushDatapoint(dp: LwcDatapoint): Option[AggrDatapoint] = {
state.get(dp.id).map { sub =>
// TODO, put in source, for now make it random to avoid dedup
nextSource += 1
val expr = sub.expr
val step = sub.step
AggrDatapoint(dp.timestamp, step, expr, nextSource.toString, dp.tags, dp.value)
}
}

private def pushDiagnosticMessage(diagMsg: LwcDiagnosticMessage): Unit = {
state.get(diagMsg.id).foreach { sub =>
context.log(sub.expr, diagMsg.message)
}
pull(in)
}

private def pushHeartbeat(hb: LwcHeartbeat): Unit = {
push(out, AggrDatapoint.heartbeat(hb.timestamp, hb.step))
private def pushHeartbeat(hb: LwcHeartbeat): AggrDatapoint = {
AggrDatapoint.heartbeat(hb.timestamp, hb.step)
}

override def onPull(): Unit = {
@@ -38,9 +38,9 @@ import com.netflix.atlas.eval.model.TimeGroup
*/
private[stream] class TimeGrouped(
context: StreamContext
) extends GraphStage[FlowShape[AggrDatapoint, TimeGroup]] {
) extends GraphStage[FlowShape[List[AggrDatapoint], List[TimeGroup]]] {

type AggrMap = scala.collection.mutable.AnyRefMap[DataExpr, AggrDatapoint.Aggregator]
type AggrMap = java.util.HashMap[DataExpr, AggrDatapoint.Aggregator]

/**
* Number of time buffers to maintain. The buffers are stored in a rolling array
@@ -60,10 +60,10 @@
context.registry
)

private val in = Inlet[AggrDatapoint]("TimeGrouped.in")
private val out = Outlet[TimeGroup]("TimeGrouped.out")
private val in = Inlet[List[AggrDatapoint]]("TimeGrouped.in")
private val out = Outlet[List[TimeGroup]]("TimeGrouped.out")

override val shape: FlowShape[AggrDatapoint, TimeGroup] = FlowShape(in, out)
override val shape: FlowShape[List[AggrDatapoint], List[TimeGroup]] = FlowShape(in, out)

private val metricName = "atlas.eval.datapoints"
private val registry = context.registry
@@ -109,9 +109,11 @@
*/
private def aggregate(i: Int, v: AggrDatapoint): Unit = {
if (!v.isHeartbeat) {
buf(i).get(v.expr) match {
case Some(aggr) => aggr.aggregate(v)
case None => buf(i).put(v.expr, AggrDatapoint.newAggregator(v, aggrSettings))
val aggr = buf(i).get(v.expr)
if (aggr == null) {
buf(i).put(v.expr, AggrDatapoint.newAggregator(v, aggrSettings))
} else {
aggr.aggregate(v)
}
}
}
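The switch from AnyRefMap to java.util.HashMap trades the Option-returning get for a plain null check, presumably to avoid allocating an Option per lookup on this hot path. The micro-pattern in isolation:

import java.util.{HashMap => JHashMap}

val counts = new JHashMap[String, Integer]()

def increment(k: String): Unit = {
  val cur = counts.get(k) // null when absent, no Option allocated
  if (cur == null) counts.put(k, 1) else counts.put(k, cur + 1)
}

increment("cpu")
increment("cpu") // counts.get("cpu") == 2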
@@ -120,15 +122,17 @@
* Push the most recently completed time group to the next stage and reset the buffer
* so it can be used for a new time window.
*/
private def flush(i: Int): Unit = {
private def flush(i: Int): Option[TimeGroup] = {
val t = timestamps(i)
if (t > 0) push(out, toTimeGroup(t, buf(i))) else pull(in)
val group = if (t > 0) Some(toTimeGroup(t, buf(i))) else None
cutoffTime = t
buf(i) = new AggrMap
group
}

private def toTimeGroup(ts: Long, aggrMap: AggrMap): TimeGroup = {
val aggregateMapForExpWithinLimits = aggrMap
import scala.jdk.CollectionConverters._
val aggregateMapForExpWithinLimits = aggrMap.asScala
.filter {
case (expr, aggr) if aggr.limitExceeded =>
context.logDatapointsExceeded(ts, expr)
@@ -145,29 +149,33 @@
}

override def onPush(): Unit = {
val v = grab(in)
val t = v.timestamp
val now = clock.wallTime()
step = v.step
if (t > now) {
droppedFutureUpdater.increment()
pull(in)
} else if (t <= cutoffTime) {
droppedOldUpdater.increment()
pull(in)
} else {
bufferedUpdater.increment()
val i = findBuffer(t)
if (i >= 0) {
aggregate(i, v)
pull(in)
val builder = List.newBuilder[TimeGroup]
grab(in).foreach { v =>
val t = v.timestamp
val now = clock.wallTime()
step = v.step
if (t > now) {
droppedFutureUpdater.increment()
} else if (t <= cutoffTime) {
droppedOldUpdater.increment()
} else {
val pos = -i - 1
flush(pos)
aggregate(pos, v)
timestamps(pos) = t
bufferedUpdater.increment()
val i = findBuffer(t)
if (i >= 0) {
aggregate(i, v)
} else {
val pos = -i - 1
builder ++= flush(pos)
aggregate(pos, v)
timestamps(pos) = t
}
}
}
val groups = builder.result()
if (groups.isEmpty)
pull(in)
else
push(out, groups)
}
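The -i - 1 arithmetic follows the java.util.Arrays.binarySearch convention, where a negative result encodes the insertion point of a missing key; findBuffer (not shown in this diff) evidently returns indices the same way. For reference:

val xs = Array(10L, 20L, 30L)
java.util.Arrays.binarySearch(xs, 20L) //  1: found at index 1
java.util.Arrays.binarySearch(xs, 25L) // -3: absent, insertion point is -(-3) - 1 = 2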

override def onPull(): Unit = {
@@ -190,8 +198,8 @@

private def flushPending(): Unit = {
if (pending.nonEmpty && isAvailable(out)) {
push(out, pending.head)
pending = pending.tail
push(out, pending)
pending = Nil
}
if (pending.isEmpty) {
completeStage()
@@ -85,7 +85,9 @@ class LwcToAggrDatapointSuite extends FunSuite {
val future = Source(data)
.map(ByteString.apply)
.map(LwcMessages.parse)
.map(msg => List(msg))
.via(new LwcToAggrDatapoint(context))
.flatMapConcat(Source.apply)
.runWith(Sink.seq[AggrDatapoint])
Await.result(future, Duration.Inf).toList
}
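With the stage now taking and emitting lists, the test wraps each parsed message in a singleton list on the way in and flattens batches on the way out, leaving the assertions untouched. The flattening step in isolation (illustrative):

import akka.actor.ActorSystem
import akka.stream.scaladsl.{Sink, Source}

implicit val system: ActorSystem = ActorSystem("flatten-sketch")

// Each List becomes a sub-source and is concatenated in order,
// so this completes with Vector(1, 2, 3).
val flattened = Source(List(List(1, 2), List(3)))
  .flatMapConcat(Source.apply)
  .runWith(Sink.seq)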