From 7e7f19a7961d9d6d1f516bbaa771c86536e8ec56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Wed, 11 Mar 2026 15:17:39 +0100 Subject: [PATCH 1/2] Ensure finish() is called exactly once in ForkJoinParallelCpgPass lifecycle --- .../scala/io/shiftleft/passes/CpgPass.scala | 51 ++++++------------- .../io/shiftleft/passes/CpgPassNewTests.scala | 42 +++++++++++++-- 2 files changed, 54 insertions(+), 39 deletions(-) diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index 58509d6f3..e6fd5e2a7 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -78,20 +78,14 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S nDiffT = flatgraph.DiffGraphApplier.applyDiff(cpg.graph, diffGraph) } catch { case exc: Exception => - baseLogger.error(s"Pass ${name} failed", exc) + baseLogger.error(s"Pass $name failed", exc) throw exc } finally { - try { - finish() - } finally { - // the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish() - // in the reported timings, and we must have our final log message if finish() throws - val nanosStop = System.nanoTime() - val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) - baseLogger.info( - f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts." - ) - } + val nanosStop = System.nanoTime() + val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) + baseLogger.info( + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms ($fracRun%.0f%% on mutations). $nDiff%d + ${nDiffT - nDiff}%d changes committed from $nParts%d parts." + ) } } @@ -106,27 +100,12 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S runOnPart(externalBuilder, parts(0).asInstanceOf[T]) case _ => val stream = - if (!isParallel) - java.util.Arrays - .stream(parts) - .sequential() - else - java.util.Arrays - .stream(parts) - .parallel() + if (!isParallel) java.util.Arrays.stream(parts).sequential() + else java.util.Arrays.stream(parts).parallel() val diff = stream.collect( - new Supplier[DiffGraphBuilder] { - override def get(): DiffGraphBuilder = - Cpg.newDiffGraphBuilder - }, - new BiConsumer[DiffGraphBuilder, AnyRef] { - override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = - runOnPart(builder, part.asInstanceOf[T]) - }, - new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { - override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit = - leftBuilder.absorb(rightBuilder) - } + () => Cpg.newDiffGraphBuilder, + (builder: DiffGraphBuilder, part: AnyRef) => runOnPart(builder, part.asInstanceOf[T]), + (leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder) => leftBuilder.absorb(rightBuilder) ) externalBuilder.absorb(diff) } @@ -152,12 +131,12 @@ trait CpgPassBase { @deprecated("Please use createAndApply") def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit - /** Name of the pass. By default it is inferred from the name of the class, override if needed. + /** Name of the pass. By default, it is inferred from the name of the class, override if needed. */ def name: String = getClass.getName /** Runs the cpg pass, adding changes to the passed builder. Use with caution -- API is unstable. Returns max(nParts, - * 1), where nParts is either the number of parallel parts, or the number of iterarator elements in case of legacy + * 1), where nParts is either the number of parallel parts, or the number of iterator elements in case of legacy * passes. Includes init() and finish() logic. */ def runWithBuilder(builder: DiffGraphBuilder): Int @@ -172,11 +151,11 @@ trait CpgPassBase { Try(runWithBuilder(builder)) match { case Success(nParts) => baseLogger.info( - f"Pass ${name} completed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms. ${builder.size - size0}%d changes generated from ${nParts}%d parts." + f"Pass $name completed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms. ${builder.size - size0}%d changes generated from $nParts%d parts." ) nParts case Failure(exception) => - baseLogger.warn(f"Pass ${name} failed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms", exception) + baseLogger.warn(f"Pass $name failed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms", exception) -1 } } diff --git a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala index c5486de96..dc32d7975 100644 --- a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala +++ b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala @@ -1,14 +1,13 @@ package io.shiftleft.passes -import better.files.File import flatgraph.SchemaViolationException import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.NewFile import io.shiftleft.codepropertygraph.generated.language.* +import io.shiftleft.codepropertygraph.generated.nodes.NewFile import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import java.nio.file.Files +import scala.collection.mutable.ArrayBuffer class CpgPassNewTests extends AnyWordSpec with Matchers { @@ -52,6 +51,43 @@ class CpgPassNewTests extends AnyWordSpec with Matchers { pass.createAndApply() } } + + "call init and finish once around run" in { + val cpg = Cpg.empty + val events = ArrayBuffer.empty[String] + val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "lifecycle-pass") { + override def init(): Unit = events += "init" + override def generateParts(): Array[String] = Array("p1") + override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = events += "run" + override def finish(): Unit = events += "finish" + } + + pass.createAndApply() + + // all events should be in the expected order and should only occur once + events.toSeq shouldBe Seq("init", "run", "finish") + } + + "call finish once when run fails" in { + val cpg = Cpg.empty + val events = ArrayBuffer.empty[String] + val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "failing-lifecycle-pass") { + override def init(): Unit = events += "init" + override def generateParts(): Array[String] = Array("p1") + override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = { + events += "run" + throw new RuntimeException("run failed") + } + override def finish(): Unit = events += "finish" + } + + intercept[RuntimeException] { + pass.createAndApply() + } + + // all events should be in the expected order and should only occur once even if run fails + events.toSeq shouldBe Seq("init", "run", "finish") + } } } From 82a6c458e2b5749eadb6fe992be9fa789e6a2084 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20Leuth=C3=A4user?= <1417198+max-leuthaeuser@users.noreply.github.com> Date: Wed, 11 Mar 2026 16:09:41 +0100 Subject: [PATCH 2/2] revert other changes --- .../scala/io/shiftleft/passes/CpgPass.scala | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index e6fd5e2a7..f43571e44 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -78,7 +78,7 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S nDiffT = flatgraph.DiffGraphApplier.applyDiff(cpg.graph, diffGraph) } catch { case exc: Exception => - baseLogger.error(s"Pass $name failed", exc) + baseLogger.error(s"Pass ${name} failed", exc) throw exc } finally { val nanosStop = System.nanoTime() @@ -100,12 +100,27 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S runOnPart(externalBuilder, parts(0).asInstanceOf[T]) case _ => val stream = - if (!isParallel) java.util.Arrays.stream(parts).sequential() - else java.util.Arrays.stream(parts).parallel() + if (!isParallel) + java.util.Arrays + .stream(parts) + .sequential() + else + java.util.Arrays + .stream(parts) + .parallel() val diff = stream.collect( - () => Cpg.newDiffGraphBuilder, - (builder: DiffGraphBuilder, part: AnyRef) => runOnPart(builder, part.asInstanceOf[T]), - (leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder) => leftBuilder.absorb(rightBuilder) + new Supplier[DiffGraphBuilder] { + override def get(): DiffGraphBuilder = + Cpg.newDiffGraphBuilder + }, + new BiConsumer[DiffGraphBuilder, AnyRef] { + override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = + runOnPart(builder, part.asInstanceOf[T]) + }, + new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { + override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit = + leftBuilder.absorb(rightBuilder) + } ) externalBuilder.absorb(diff) } @@ -131,12 +146,12 @@ trait CpgPassBase { @deprecated("Please use createAndApply") def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit - /** Name of the pass. By default, it is inferred from the name of the class, override if needed. + /** Name of the pass. By default it is inferred from the name of the class, override if needed. */ def name: String = getClass.getName /** Runs the cpg pass, adding changes to the passed builder. Use with caution -- API is unstable. Returns max(nParts, - * 1), where nParts is either the number of parallel parts, or the number of iterator elements in case of legacy + * 1), where nParts is either the number of parallel parts, or the number of iterarator elements in case of legacy * passes. Includes init() and finish() logic. */ def runWithBuilder(builder: DiffGraphBuilder): Int @@ -151,11 +166,11 @@ trait CpgPassBase { Try(runWithBuilder(builder)) match { case Success(nParts) => baseLogger.info( - f"Pass $name completed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms. ${builder.size - size0}%d changes generated from $nParts%d parts." + f"Pass ${name} completed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms. ${builder.size - size0}%d changes generated from ${nParts}%d parts." ) nParts case Failure(exception) => - baseLogger.warn(f"Pass $name failed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms", exception) + baseLogger.warn(f"Pass ${name} failed in ${(System.nanoTime() - nanoStart) * 1e-6}%.0f ms", exception) -1 } }