diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py index 6d34264eac..3ba057d059 100644 --- a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -34,8 +34,8 @@ def test_user_defined_periods(self): periods[idx]["unit_index"] = unit_index period_start = num_samples // 4 period_duration = num_samples // 2 - periods[idx]["start_sample_index"] = period_start - periods[idx]["end_sample_index"] = period_start + period_duration + periods[idx]["start_sample_index"] = period_start - unit_index * 10 + periods[idx]["end_sample_index"] = period_start + period_duration + unit_index * 10 periods[idx]["segment_index"] = segment_index sorting_analyzer = self._prepare_sorting_analyzer( @@ -48,8 +48,17 @@ def test_user_defined_periods(self): minimum_valid_period_duration=1, ) # check that valid periods correspond to user defined periods - ext_periods = ext.get_data(outputs="numpy") - np.testing.assert_array_equal(ext_periods, periods) + ext_periods_numpy = ext.get_data(outputs="numpy") + np.testing.assert_array_equal(ext_periods_numpy, periods) + + # check that `numpy` and `by_unit` outputs are the same + ext_periods_by_unit = ext.get_data(outputs="by_unit") + for segment_index in range(num_segments): + for unit_index, unit_id in enumerate(unit_ids): + periods_numpy_seg0 = ext_periods_numpy[ext_periods_numpy["segment_index"] == segment_index] + periods_numpy_unit = periods_numpy_seg0[periods_numpy_seg0["unit_index"] == unit_index] + period = [(periods_numpy_unit["start_sample_index"][0], periods_numpy_unit["end_sample_index"][0])] + assert period == ext_periods_by_unit[segment_index][unit_id] def test_user_defined_periods_as_arrays(self): unit_ids = self.sorting.unit_ids diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index 020396e38d..a76cd703aa 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -548,12 +548,12 @@ def _get_data(self, outputs: str = "by_unit"): for segment_index in range(self.sorting_analyzer.get_num_segments()): segment_mask = good_periods_array["segment_index"] == segment_index periods_dict = {} - for unit_index in unit_ids: - periods_dict[unit_index] = [] + for unit_index, unit_id in enumerate(unit_ids): + periods_dict[unit_id] = [] unit_mask = good_periods_array["unit_index"] == unit_index good_periods_unit_segment = good_periods_array[segment_mask & unit_mask] for start, end in good_periods_unit_segment[["start_sample_index", "end_sample_index"]]: - periods_dict[unit_index].append((start, end)) + periods_dict[unit_id].append((start, end)) good_periods.append(periods_dict) return good_periods