diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index f43571e44..a08d019b2 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -5,8 +5,10 @@ import io.shiftleft.SerializedCpg import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder} import org.slf4j.{Logger, LoggerFactory, MDC} +import java.util.concurrent.ConcurrentLinkedQueue import java.util.function.{BiConsumer, Supplier} import scala.annotation.nowarn +import scala.jdk.CollectionConverters.* import scala.concurrent.duration.DurationLong import scala.util.{Failure, Success, Try} @@ -137,6 +139,62 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S } +/** A [[ForkJoinParallelCpgPass]] that additionally maintains a thread-local accumulator of type [[R]] which is merged + * across all threads after processing completes. This enables map-reduce style aggregation alongside the usual + * DiffGraph-based graph modifications. + * + * Each thread gets its own accumulator instance (via [[newAccumulator]]). After all parts are processed, the + * accumulators are merged using [[mergeAccumulators]] and the result is passed to [[onAccumulatorComplete]]. + * + * @tparam T + * the part type (same as in [[ForkJoinParallelCpgPass]]) + * @tparam R + * the accumulator type + */ +abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, outName: String = "") + extends ForkJoinParallelCpgPass[T](cpg, outName) { + + /** Create a fresh, empty accumulator. Called once per thread. */ + protected def newAccumulator(): R + + /** Merge two accumulators. Must be associative. The result may reuse either argument. */ + protected def mergeAccumulators(left: R, right: R): R + + /** Process a single part, writing graph changes to `builder` and aggregated data to `acc`. */ + protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: R, part: T): Unit + + /** Called after all parts are processed with the fully merged accumulator. */ + protected def onAccumulatorComplete(acc: R): Unit = {} + + private val accumulators = new ConcurrentLinkedQueue[R]() + + private val threadLocalAcc: ThreadLocal[R] = new ThreadLocal[R]() + + final override def runOnPart(builder: DiffGraphBuilder, part: T): Unit = { + var acc = threadLocalAcc.get() + if (acc == null) { + acc = newAccumulator() + threadLocalAcc.set(acc) + accumulators.add(acc) + } + runOnPartWithAccumulator(builder, acc, part) + } + + override def init(): Unit = { + accumulators.clear() + threadLocalAcc.remove() + super.init() + } + + override def finish(): Unit = { + val merged = accumulators.asScala.reduceOption(mergeAccumulators).getOrElse(newAccumulator()) + onAccumulatorComplete(merged) + accumulators.clear() + threadLocalAcc.remove() + super.finish() + } +} + trait CpgPassBase { protected def baseLogger: Logger = LoggerFactory.getLogger(getClass) diff --git a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala index dc32d7975..467e786ba 100644 --- a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala +++ b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala @@ -90,4 +90,114 @@ class CpgPassNewTests extends AnyWordSpec with Matchers { } } + "ForkJoinParallelCpgPassWithAccumulator" should { + "merge accumulators and invoke completion callback once" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-pass") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int] + override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] = + left ++= right + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = acc += part.length + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array("a", "bb", "ccc") + override def isParallel: Boolean = false + } + + pass.createAndApply() + + completed.toSeq shouldBe Seq(6) + } + + "use a fresh accumulator when there are no parts" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-empty") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer(42) + override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] = + left ++= right + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = () + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array.empty + } + + pass.createAndApply() + + completed.toSeq shouldBe Seq(42) + } + + "clear accumulator state between runs" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-rerun") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int] + override protected def mergeAccumulators( + left: ArrayBuffer[Int], + right: ArrayBuffer[Int] + ): ArrayBuffer[Int] = { + left ++= right + } + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = acc += part.toInt + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array("1", "2", "3") + override def isParallel: Boolean = false + } + + pass.createAndApply() + pass.createAndApply() + + completed.toSeq shouldBe Seq(6, 6) + } + + "invoke completion callback once when a part fails" in { + val cpg = Cpg.empty + val events = ArrayBuffer.empty[String] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, Int] = + new ForkJoinParallelCpgPassWithAccumulator[String, Int](cpg, "acc-fail") { + override protected def newAccumulator(): Int = 0 + override protected def mergeAccumulators(left: Int, right: Int): Int = left + right + override protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: Int, part: String): Unit = { + events += "run" + throw new RuntimeException("boom") + } + override protected def onAccumulatorComplete(acc: Int): Unit = events += s"complete:$acc" + override def generateParts(): Array[String] = Array("p1") + override def isParallel: Boolean = false + override def init(): Unit = { + events += "init" + super.init() + } + override def finish(): Unit = { + events += "finish" + super.finish() + } + } + + intercept[RuntimeException] { + pass.createAndApply() + } + + events.toSeq shouldBe Seq("init", "run", "finish", "complete:0") + } + } + }