diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index a6e4932cb1..d6a89dab22 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -292,6 +292,39 @@ def remove_cases(self, case_keys): self.remove_benchmark(key) (self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases)) + def set_precomputed_results(self, precomputed_results, verbose=False): + """Set precomputed results for some cases. This is useful when you want to compute results outside of the benchmark and + then set them in the benchmark. + + Parameters + ---------- + precomputed_results : dict + A dict with the same keys as cases and values are dict with the results to set for each case. + The keys of the inner dict must be the same as the keys of the benchmark result. + 'run_time' is a special key that will be set to 0.0 if not present in the precomputed results. + verbose : bool, default: False + Whether to print the keys of the precomputed results when setting them. + """ + + for key in precomputed_results.keys(): + assert key in self.cases, f"Key {key} in precomputed_results is not in cases" + benchmark = self.create_benchmark(key) + if verbose: + print("### Set benchmark", key, "###") + + for k, v in precomputed_results[key].items(): + benchmark.result[k] = v + if "run_time" not in benchmark.result: + benchmark.result["run_time"] = 0.0 + if verbose: + print(f"Warning: 'run_time' is not in the precomputed results for key {key}, setting it to 0.0") + + self.benchmarks[key] = benchmark + bench_folder = self.folder / "results" / self.key_to_str(key) + bench_folder.mkdir(exist_ok=True) + benchmark.save_run(bench_folder) + benchmark.save_main(bench_folder) + def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs): if case_keys is None: case_keys = list(self.cases.keys())