Quality Metrics Tutorial

After spike sorting, you might want to validate the ‘goodness’ of the sorted units. This can be done with the spikeinterface.metrics module, which computes several quality metrics for the sorted units.

import spikeinterface.core as si
from spikeinterface.metrics import (
    compute_snrs,
    compute_presence_ratios,
    compute_isi_violations,
)

First, let’s generate a simulated recording and sorting:

recording, sorting = si.generate_ground_truth_recording()
print(recording)
print(sorting)
GroundTruthRecording (InjectTemplatesRecording): 4 channels - 25.0kHz - 1 segments
                      250,000 samples - 10.00s - float32 dtype - 3.81 MiB
GroundTruthSorting (NumpySorting): 10 units - 1 segments - 25.0kHz

Create SortingAnalyzer

To compute quality metrics, we first need to create a SortingAnalyzer.

analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")
print(analyzer)
estimate_sparsity (no parallelization): 100%|██████████| 10/10 [00:00<00:00, 418.83it/s]
SortingAnalyzer: 4 channels - 10 units - 1 segments - memory - sparse - has recording
Loaded 0 extensions

Depending on which metrics we want to compute, we first need to compute some prerequisite extensions (if a required extension is missing, an error is raised).

analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=600, seed=2205)
analyzer.compute("waveforms", ms_before=1.3, ms_after=2.6, n_jobs=2)
analyzer.compute("templates", operators=["average", "median", "std"])
analyzer.compute("noise_levels")

print(analyzer)
compute_waveforms (workers: 2 processes): 100%|██████████| 10/10 [00:00<00:00, 80.26it/s]

noise_level (no parallelization): 100%|██████████| 20/20 [00:00<00:00, 272.54it/s]
SortingAnalyzer: 4 channels - 10 units - 1 segments - memory - sparse - has recording
Loaded 4 extensions: random_spikes, waveforms, templates, noise_levels
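To make the role of the noise_levels extension more concrete, here is a minimal pure-Python sketch of one common robust noise estimator, the median absolute deviation (MAD) scaled by 1/0.6745, together with an SNR defined as a template's peak absolute amplitude divided by that noise level. This is an illustration under stated assumptions, not SpikeInterface's implementation: the helper names mad_noise_level and template_snr are hypothetical, and the library's exact choices (peak sign, which channel is used, how noise chunks are sampled) may differ.

```python
import statistics

# Assumption: MAD-based noise estimate. Dividing the MAD by 0.6745 gives an
# estimate of the standard deviation when the noise is Gaussian.
MAD_TO_STD = 1 / 0.6744897501960817


def mad_noise_level(trace):
    """Hypothetical helper: robust noise estimate for one channel (MAD / 0.6745).

    Unlike the plain standard deviation, this is barely affected by the
    occasional large spike in the trace.
    """
    center = statistics.median(trace)
    mad = statistics.median(abs(x - center) for x in trace)
    return mad * MAD_TO_STD


def template_snr(template, noise_level):
    """Hypothetical helper: peak absolute template amplitude / noise level."""
    return max(abs(x) for x in template) / noise_level


# Toy single-channel trace with one large "spike" at index 5.
trace = [0.5, -0.5, 1.0, -1.0, 0.0, 20.0, -0.5]
noise = mad_noise_level(trace)  # the MAD is 0.5, so noise is about 0.741
print(noise, statistics.pstdev(trace))  # pstdev is inflated by the spike (roughly 7)
print(template_snr([-0.2, -8.0, 3.0], noise))
```

The point of the robust estimator is visible in the toy trace: a single spike inflates the standard deviation roughly tenfold, while the MAD-based estimate stays close to the spread of the background samples.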

The spikeinterface.metrics module provides a set of functions for computing metrics in a compact way. To compute a single metric, simply call the corresponding function, as shown below. Each function returns a dictionary that maps unit IDs to metric values (compute_isi_violations returns two such dictionaries: the violation ratio and the violation count), and each one exposes adjustable parameters.

presence_ratios = compute_presence_ratios(analyzer)
print(presence_ratios)
isi_violation_ratio, isi_violations_count = compute_isi_violations(analyzer)
print(isi_violation_ratio)
snrs = compute_snrs(analyzer)
print(snrs)
{np.str_('0'): nan, np.str_('1'): nan, np.str_('2'): nan, np.str_('3'): nan, np.str_('4'): nan, np.str_('5'): nan, np.str_('6'): nan, np.str_('7'): nan, np.str_('8'): nan, np.str_('9'): nan}
{np.str_('0'): np.float64(0.0), np.str_('1'): np.float64(0.0), np.str_('2'): np.float64(0.0), np.str_('3'): np.float64(0.0), np.str_('4'): np.float64(0.0), np.str_('5'): np.float64(0.0), np.str_('6'): np.float64(0.0), np.str_('7'): np.float64(0.0), np.str_('8'): np.float64(0.0), np.str_('9'): np.float64(0.0)}
{np.str_('0'): np.float64(25.83801378827219), np.str_('1'): np.float64(19.555919854107906), np.str_('2'): np.float64(13.943703950433207), np.str_('3'): np.float64(43.10769319900704), np.str_('4'): np.float64(24.33971871834926), np.str_('5'): np.float64(17.08276961713479), np.str_('6'): np.float64(5.984988872295748), np.str_('7'): np.float64(22.519068265389812), np.str_('8'): np.float64(13.846240555575847), np.str_('9'): np.float64(21.1011300531286)}
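To illustrate what the ISI violation ratio measures, here is a minimal pure-Python sketch of one common formulation: the rate of inter-spike intervals shorter than a refractory threshold, divided by the unit's overall firing rate. The function name isi_violations is hypothetical, and the 1.5 ms threshold and 0 ms minimum ISI are assumed defaults chosen for illustration, not values taken from this tutorial.

```python
import math


def isi_violations(spike_times_s, duration_s, isi_threshold_s=0.0015, min_isi_s=0.0):
    """Hypothetical helper: return (violation_ratio, violation_count) for one unit.

    spike_times_s must be sorted and in seconds. The ratio is the rate of
    refractory-period violations relative to the unit's overall firing rate,
    so 0.0 means no inter-spike interval fell below the threshold.
    """
    n = len(spike_times_s)
    if n == 0:
        return math.nan, 0
    # Inter-spike intervals between consecutive spikes.
    isis = [later - earlier for earlier, later in zip(spike_times_s, spike_times_s[1:])]
    count = sum(1 for isi in isis if isi < isi_threshold_s)
    # Total time in which a violating spike could fall: a window of
    # (threshold - min_isi) on each side of every spike.
    violation_time = 2 * n * (isi_threshold_s - min_isi_s)
    total_rate = n / duration_s
    violation_rate = count / violation_time
    return violation_rate / total_rate, count


# Two of the four intervals (0.5 ms and 1 ms) fall below the 1.5 ms threshold.
ratio, count = isi_violations([0.1, 0.1005, 0.5, 0.9, 0.901], duration_s=1.0)
print(ratio, count)
```

Normalising by the firing rate makes the ratio comparable across units: a high-rate unit is expected to produce some short intervals by chance, so the same raw count means less for it than for a sparse unit.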

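To compute more than one metric at once, we can use the SortingAnalyzer.compute("quality_metrics") method and specify which metrics we want. The results can then be retrieved as a pandas.DataFrame with the get_data() method. Here we also use metric_params to shorten the presence-ratio bin to 2 s: the simulated recording is only 10 s long, which is shorter than the default bin duration, and that is why the presence ratios printed above were all NaN.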

metrics_ext = analyzer.compute(
    "quality_metrics",
    metric_names=["presence_ratio", "snr", "amplitude_cutoff"],
    metric_params={
        "presence_ratio": {"bin_duration_s": 2.0},
    }
)
metrics = metrics_ext.get_data()
print(metrics)
   presence_ratio        snr
0             1.0  25.838014
1             1.0  19.555920
2             1.0  13.943704
3             1.0  43.107693
4             1.0  24.339719
5             1.0  17.082770
6             1.0   5.984989
7             1.0  22.519068
8             1.0  13.846241
9             1.0  21.101130
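The presence ratio itself is simple to sketch. Below is a minimal pure-Python version, assuming the metric is the fraction of time bins that contain at least one spike, and that a recording shorter than one bin yields NaN (consistent with the NaN values printed earlier for this 10 s recording). The function name presence_ratio, the 60 s default bin, and the choice to ignore a trailing partial bin are assumptions for illustration, not SpikeInterface's implementation.

```python
import math


def presence_ratio(spike_times_s, duration_s, bin_duration_s=60.0):
    """Hypothetical helper: fraction of full-length bins containing >= 1 spike.

    Returns NaN when the recording is shorter than a single bin, so a short
    recording needs a shorter bin to produce a value.
    """
    n_bins = int(duration_s // bin_duration_s)
    if n_bins == 0:
        return math.nan
    occupied = set()
    for t in spike_times_s:
        b = int(t // bin_duration_s)
        if 0 <= b < n_bins:  # ignore spikes falling in a trailing partial bin
            occupied.add(b)
    return len(occupied) / n_bins


spikes = [0.5, 1.2, 2.1, 5.0, 9.9]
print(presence_ratio(spikes, 10.0, bin_duration_s=2.0))  # 4 of 5 bins occupied
print(presence_ratio(spikes, 10.0))  # recording shorter than one 60 s bin
```

With 2 s bins, the toy unit is silent only in the 6 to 8 s bin, giving 0.8; with the assumed 60 s default, there is no complete bin at all, so the value is undefined.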

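Note that amplitude_cutoff, although requested, is absent from the table above: it depends on spike amplitudes (the "spike_amplitudes" extension), which we have not computed for this analyzer.

Some metrics are based on principal component (PC) scores, so the "principal_components" extension must be computed first. For instance, "mahalanobis" (reported as the isolation_distance and l_ratio columns) and "d_prime" need it. The new columns are added to the existing table, so snr and presence_ratio from the previous call remain in the output: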

analyzer.compute("principal_components", n_components=3, mode="by_channel_global", whiten=True)

metrics_ext = analyzer.compute(
    "quality_metrics",
    metric_names=[
        "mahalanobis",
        "d_prime",
    ],
)
metrics = metrics_ext.get_data()
print(metrics)
Fitting PCA: 100%|██████████| 10/10 [00:00<00:00, 189.57it/s]

Projecting waveforms: 100%|██████████| 10/10 [00:00<00:00, 2416.07it/s]
   isolation_distance       l_ratio    d_prime        snr  presence_ratio
0          824.066487  1.606902e-17  11.735882  25.838014             1.0
1           42.675790  9.380239e-03   2.826587  19.555920             1.0
2           91.312185  5.309548e-07   4.070065  13.943704             1.0
3         1141.004760  1.511148e-10  10.151726  43.107693             1.0
4          133.573906  2.592162e-05   4.421751  24.339719             1.0
5           40.870338  3.724572e-03   1.972816  17.082770             1.0
6           46.060507  1.378295e-01   3.044455   5.984989             1.0
7           79.077168  1.280450e-03   2.868886  22.519068             1.0
8           15.857721  4.136081e-01   0.914812  13.846241             1.0
9           49.301625  1.632594e-04   3.562738  21.101130             1.0
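For intuition about the d_prime column, here is a minimal pure-Python sketch of the d′ statistic: the distance between the mean score of a unit's spikes and the mean score of all other spikes, in units of their pooled standard deviation. As we understand it, SpikeInterface first projects the PC features onto a single axis with linear discriminant analysis; that projection step is omitted here, and the sketch starts from already-projected 1-D scores. The function name d_prime and the toy scores are hypothetical.

```python
import math
import statistics


def d_prime(unit_scores, other_scores):
    """Hypothetical helper: separation of two 1-D score distributions.

    d' = |mean_a - mean_b| / sqrt((var_a + var_b) / 2), using population
    variances. The absolute value makes the result independent of the sign
    of the projection axis. Larger values mean better separation.
    """
    mean_a = statistics.fmean(unit_scores)
    mean_b = statistics.fmean(other_scores)
    var_a = statistics.pvariance(unit_scores)
    var_b = statistics.pvariance(other_scores)
    return abs(mean_a - mean_b) / math.sqrt(0.5 * (var_a + var_b))


# Toy projected scores: the unit's spikes cluster around 5, the rest around 0.
unit = [4.0, 5.0, 6.0]
others = [-1.0, 0.0, 1.0]
print(d_prime(unit, others))
```

In the table above, units with small d_prime values (for example unit 8) are the ones whose spikes overlap most with the rest of the data along the discriminant axis.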
