Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import io.shiftleft.SerializedCpg
import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder}
import org.slf4j.{Logger, LoggerFactory, MDC}

import java.util.concurrent.ConcurrentLinkedQueue
import java.util.function.{BiConsumer, Supplier}
import scala.annotation.nowarn
import scala.jdk.CollectionConverters.*
import scala.concurrent.duration.DurationLong
import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -137,6 +139,62 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S

}

/** A [[ForkJoinParallelCpgPass]] that additionally maintains a thread-local accumulator of type [[R]] which is merged
* across all threads after processing completes. This enables map-reduce style aggregation alongside the usual
* DiffGraph-based graph modifications.
*
* Each thread gets its own accumulator instance (via [[newAccumulator]]). After all parts are processed, the
* accumulators are merged using [[mergeAccumulators]] and the result is passed to [[onAccumulatorComplete]].
*
* @tparam T
* the part type (same as in [[ForkJoinParallelCpgPass]])
* @tparam R
* the accumulator type
*/
abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, outName: String = "")
extends ForkJoinParallelCpgPass[T](cpg, outName) {

/** Create a fresh, empty accumulator. Called once per thread. */
protected def newAccumulator(): R

/** Merge two accumulators. Must be associative. The result may reuse either argument. */
protected def mergeAccumulators(left: R, right: R): R

/** Process a single part, writing graph changes to `builder` and aggregated data to `acc`. */
protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: R, part: T): Unit

/** Called after all parts are processed with the fully merged accumulator. */
protected def onAccumulatorComplete(acc: R): Unit = {}

private val accumulators = new ConcurrentLinkedQueue[R]()

private val threadLocalAcc: ThreadLocal[R] = new ThreadLocal[R]()

final override def runOnPart(builder: DiffGraphBuilder, part: T): Unit = {
var acc = threadLocalAcc.get()
if (acc == null) {
acc = newAccumulator()
threadLocalAcc.set(acc)
accumulators.add(acc)
}
runOnPartWithAccumulator(builder, acc, part)
}

override def init(): Unit = {
accumulators.clear()
threadLocalAcc.remove()
super.init()
}

override def finish(): Unit = {
val merged = accumulators.asScala.reduceOption(mergeAccumulators).getOrElse(newAccumulator())
onAccumulatorComplete(merged)
accumulators.clear()
threadLocalAcc.remove()
super.finish()
}
}

trait CpgPassBase {

protected def baseLogger: Logger = LoggerFactory.getLogger(getClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,114 @@ class CpgPassNewTests extends AnyWordSpec with Matchers {
}
}

"ForkJoinParallelCpgPassWithAccumulator" should {
"merge accumulators and invoke completion callback once" in {
val cpg = Cpg.empty
val completed = ArrayBuffer.empty[Int]

val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] =
new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-pass") {
override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int]
override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] =
left ++= right
override protected def runOnPartWithAccumulator(
builder: DiffGraphBuilder,
acc: ArrayBuffer[Int],
part: String
): Unit = acc += part.length
override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum
override def generateParts(): Array[String] = Array("a", "bb", "ccc")
override def isParallel: Boolean = false
}

pass.createAndApply()

completed.toSeq shouldBe Seq(6)
}

"use a fresh accumulator when there are no parts" in {
val cpg = Cpg.empty
val completed = ArrayBuffer.empty[Int]

val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] =
new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-empty") {
override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer(42)
override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] =
left ++= right
override protected def runOnPartWithAccumulator(
builder: DiffGraphBuilder,
acc: ArrayBuffer[Int],
part: String
): Unit = ()
override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum
override def generateParts(): Array[String] = Array.empty
}

pass.createAndApply()

completed.toSeq shouldBe Seq(42)
}

"clear accumulator state between runs" in {
val cpg = Cpg.empty
val completed = ArrayBuffer.empty[Int]

val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] =
new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-rerun") {
override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int]
override protected def mergeAccumulators(
left: ArrayBuffer[Int],
right: ArrayBuffer[Int]
): ArrayBuffer[Int] = {
left ++= right
}
override protected def runOnPartWithAccumulator(
builder: DiffGraphBuilder,
acc: ArrayBuffer[Int],
part: String
): Unit = acc += part.toInt
override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum
override def generateParts(): Array[String] = Array("1", "2", "3")
override def isParallel: Boolean = false
}

pass.createAndApply()
pass.createAndApply()

completed.toSeq shouldBe Seq(6, 6)
}

"invoke completion callback once when a part fails" in {
val cpg = Cpg.empty
val events = ArrayBuffer.empty[String]

val pass: ForkJoinParallelCpgPassWithAccumulator[String, Int] =
new ForkJoinParallelCpgPassWithAccumulator[String, Int](cpg, "acc-fail") {
override protected def newAccumulator(): Int = 0
override protected def mergeAccumulators(left: Int, right: Int): Int = left + right
override protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: Int, part: String): Unit = {
events += "run"
throw new RuntimeException("boom")
}
override protected def onAccumulatorComplete(acc: Int): Unit = events += s"complete:$acc"
override def generateParts(): Array[String] = Array("p1")
override def isParallel: Boolean = false
override def init(): Unit = {
events += "init"
super.init()
}
override def finish(): Unit = {
events += "finish"
super.finish()
}
}

intercept[RuntimeException] {
pass.createAndApply()
}

events.toSeq shouldBe Seq("init", "run", "finish", "complete:0")
}
}

}
Loading