diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index f43571e44..e76d09076 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -49,7 +49,20 @@ abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelC * methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct * passes eagerly, and releases them only when the entire chain has run. * */ -abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: String = "") extends CpgPassBase { +abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outname: String = "") + extends ForkJoinParallelCpgPassWithAccumulator[T, Null](cpg, outname) { + def runOnPart(builder: DiffGraphBuilder, part: T): Unit + + override def createAccumulator(): Null = null + override def runOnPart(builder: DiffGraphBuilder, part: T, acc: Null): Unit = runOnPart(builder, part) + override def runOnFinalAccumulator(builder: DiffGraphBuilder, accumulator: Null): Unit = {} + override def mergeAccumulator(left: Null, accumulator: Null): Unit = {} +} + +abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, Accumulator <: AnyRef]( + cpg: Cpg, + @nowarn outName: String = "" +) extends CpgPassBase { type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder // generate Array of parts that can be processed in parallel def generateParts(): Array[? <: AnyRef] @@ -58,10 +71,14 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S // release large data structures and external resources def finish(): Unit = {} // main function: add desired changes to builder - def runOnPart(builder: DiffGraphBuilder, part: T): Unit + def runOnPart(builder: DiffGraphBuilder, part: T, accumulator: Accumulator): Unit // Override this to disable parallelism of passes. Useful for debugging. def isParallel: Boolean = true + def createAccumulator(): Accumulator + def mergeAccumulator(left: Accumulator, accumulator: Accumulator): Unit + def runOnFinalAccumulator(builder: DiffGraphBuilder, accumulator: Accumulator): Unit + override def createAndApply(): Unit = { baseLogger.info(s"Start of pass: $name") val nanosStart = System.nanoTime() @@ -95,9 +112,6 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S val parts = generateParts() val nParts = parts.size nParts match { - case 0 => - case 1 => - runOnPart(externalBuilder, parts(0).asInstanceOf[T]) case _ => val stream = if (!isParallel) @@ -108,20 +122,26 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S java.util.Arrays .stream(parts) .parallel() - val diff = stream.collect( - new Supplier[DiffGraphBuilder] { - override def get(): DiffGraphBuilder = - Cpg.newDiffGraphBuilder + val (diff, acc) = stream.collect( + new Supplier[(DiffGraphBuilder, Accumulator)] { + override def get(): (DiffGraphBuilder, Accumulator) = + (Cpg.newDiffGraphBuilder, createAccumulator()) }, - new BiConsumer[DiffGraphBuilder, AnyRef] { - override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = - runOnPart(builder, part.asInstanceOf[T]) + new BiConsumer[(DiffGraphBuilder, Accumulator), AnyRef] { + override def accept(builder: (DiffGraphBuilder, Accumulator), part: AnyRef): Unit = + runOnPart(builder._1, part.asInstanceOf[T], builder._2) }, - new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { - override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit = - leftBuilder.absorb(rightBuilder) + new BiConsumer[(DiffGraphBuilder, Accumulator), (DiffGraphBuilder, Accumulator)] { + override def accept( + leftBuilder: (DiffGraphBuilder, Accumulator), + rightBuilder: (DiffGraphBuilder, Accumulator) + ): Unit = { + leftBuilder._1.absorb(rightBuilder._1) + mergeAccumulator(leftBuilder._2, rightBuilder._2) + } } ) + runOnFinalAccumulator(diff, acc) externalBuilder.absorb(diff) } nParts