Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,20 @@ abstract class CpgPass(cpg: Cpg, outName: String = "") extends ForkJoinParallelC
* methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct
* passes eagerly, and releases them only when the entire chain has run.
* */
abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: String = "") extends CpgPassBase {
abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outname: String = "")
extends ForkJoinParallelCpgPassWithAccumulator[T, Null](cpg, outname) {
def runOnPart(builder: DiffGraphBuilder, part: T): Unit

override def createAccumulator(): Null = null
override def runOnPart(builder: DiffGraphBuilder, part: T, acc: Null): Unit = runOnPart(builder, part)
override def runOnFinalAccumulator(builder: DiffGraphBuilder, accumulator: Null): Unit = {}
override def mergeAccumulator(left: Null, accumulator: Null): Unit = {}
}

abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, Accumulator <: AnyRef](
cpg: Cpg,
@nowarn outName: String = ""
) extends CpgPassBase {
type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder
// generate Array of parts that can be processed in parallel
def generateParts(): Array[? <: AnyRef]
Expand All @@ -58,10 +71,14 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
// release large data structures and external resources
def finish(): Unit = {}
// main function: add desired changes to builder
def runOnPart(builder: DiffGraphBuilder, part: T): Unit
def runOnPart(builder: DiffGraphBuilder, part: T, accumulator: Accumulator): Unit
// Override this to disable parallelism of passes. Useful for debugging.
def isParallel: Boolean = true

def createAccumulator(): Accumulator
def mergeAccumulator(left: Accumulator, accumulator: Accumulator): Unit
def runOnFinalAccumulator(builder: DiffGraphBuilder, accumulator: Accumulator): Unit

override def createAndApply(): Unit = {
baseLogger.info(s"Start of pass: $name")
val nanosStart = System.nanoTime()
Expand Down Expand Up @@ -95,9 +112,6 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
val parts = generateParts()
val nParts = parts.size
nParts match {
case 0 =>
case 1 =>
runOnPart(externalBuilder, parts(0).asInstanceOf[T])
case _ =>
val stream =
if (!isParallel)
Expand All @@ -108,20 +122,26 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S
java.util.Arrays
.stream(parts)
.parallel()
val diff = stream.collect(
new Supplier[DiffGraphBuilder] {
override def get(): DiffGraphBuilder =
Cpg.newDiffGraphBuilder
val (diff, acc) = stream.collect(
new Supplier[(DiffGraphBuilder, Accumulator)] {
override def get(): (DiffGraphBuilder, Accumulator) =
(Cpg.newDiffGraphBuilder, createAccumulator())
},
new BiConsumer[DiffGraphBuilder, AnyRef] {
override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit =
runOnPart(builder, part.asInstanceOf[T])
new BiConsumer[(DiffGraphBuilder, Accumulator), AnyRef] {
override def accept(builder: (DiffGraphBuilder, Accumulator), part: AnyRef): Unit =
runOnPart(builder._1, part.asInstanceOf[T], builder._2)
},
new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] {
override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit =
leftBuilder.absorb(rightBuilder)
new BiConsumer[(DiffGraphBuilder, Accumulator), (DiffGraphBuilder, Accumulator)] {
override def accept(
leftBuilder: (DiffGraphBuilder, Accumulator),
rightBuilder: (DiffGraphBuilder, Accumulator)
): Unit = {
leftBuilder._1.absorb(rightBuilder._1)
mergeAccumulator(leftBuilder._2, rightBuilder._2)
}
}
)
runOnFinalAccumulator(diff, acc)
externalBuilder.absorb(diff)
}
nParts
Expand Down
Loading