diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index 58509d6f3..f43571e44 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -81,17 +81,11 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S baseLogger.error(s"Pass ${name} failed", exc) throw exc } finally { - try { - finish() - } finally { - // the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish() - // in the reported timings, and we must have our final log message if finish() throws - val nanosStop = System.nanoTime() - val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) - baseLogger.info( - f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts." - ) - } + val nanosStop = System.nanoTime() + val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) + baseLogger.info( + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms ($fracRun%.0f%% on mutations). $nDiff%d + ${nDiffT - nDiff}%d changes committed from $nParts%d parts." + ) } } diff --git a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala index c5486de96..dc32d7975 100644 --- a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala +++ b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala @@ -1,14 +1,13 @@ package io.shiftleft.passes -import better.files.File import flatgraph.SchemaViolationException import io.shiftleft.codepropertygraph.generated.Cpg -import io.shiftleft.codepropertygraph.generated.nodes.NewFile import io.shiftleft.codepropertygraph.generated.language.* +import io.shiftleft.codepropertygraph.generated.nodes.NewFile import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import java.nio.file.Files +import scala.collection.mutable.ArrayBuffer class CpgPassNewTests extends AnyWordSpec with Matchers { @@ -52,6 +51,43 @@ class CpgPassNewTests extends AnyWordSpec with Matchers { pass.createAndApply() } } + + "call init and finish once around run" in { + val cpg = Cpg.empty + val events = ArrayBuffer.empty[String] + val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "lifecycle-pass") { + override def init(): Unit = events += "init" + override def generateParts(): Array[String] = Array("p1") + override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = events += "run" + override def finish(): Unit = events += "finish" + } + + pass.createAndApply() + + // all events should be in the expected order and should only occur once + events.toSeq shouldBe Seq("init", "run", "finish") + } + + "call finish once when run fails" in { + val cpg = Cpg.empty + val events = ArrayBuffer.empty[String] + val pass: ForkJoinParallelCpgPass[String] = new ForkJoinParallelCpgPass[String](cpg, "failing-lifecycle-pass") { + override def init(): Unit = events += "init" + override def generateParts(): Array[String] = Array("p1") + override def runOnPart(builder: DiffGraphBuilder, part: String): Unit = { + events += "run" + throw new RuntimeException("run failed") + } + override def finish(): Unit = events += "finish" + } + + intercept[RuntimeException] { + pass.createAndApply() + } + + // all events should be in the expected order and should only occur once even if run fails + events.toSeq shouldBe Seq("init", "run", "finish") + } } }