Fix OOM in template matching for large cluster counts#1027
alowet wants to merge 6 commits into MouseLand:main
Conversation
When clustering produces many templates (e.g. 17K+), prepare_matching computes UtU and ctc tensors that scale as O(N^2) and can exceed GPU memory (30+ GiB for 17K templates). This adds an adaptive fallback that computes ctc columns on-the-fly in chunks of 128 when the full tensor would not fit, keeping peak memory around 1.4 GiB per chunk. The original precomputed path is preserved for smaller template counts. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
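The chunked fallback described above can be sketched as follows. This is illustrative only: `utu_column_chunks` and its shapes are assumptions, not Kilosort's actual API; the einsum is the one from the traceback below.

```python
import torch

def utu_column_chunks(U, chunk=128):
    """Yield blocks of columns of UtU instead of materializing the full
    O(N^2) result of torch.einsum('ikl, jml -> ijkm', U, U) at once.
    Each block covers only `chunk` templates, so peak memory is
    O(N * chunk); the caller consumes and discards one block at a time."""
    for j0 in range(0, U.shape[0], chunk):
        yield j0, torch.einsum('ikl, jml -> ijkm', U, U[j0:j0 + chunk])

# Reassembling the blocks reproduces the full einsum (done here only to
# verify equivalence; the point of the fallback is NOT to do this):
U = torch.randn(10, 3, 4)
full = torch.einsum('ikl, jml -> ijkm', U, U)
parts = torch.cat([blk for _, blk in utu_column_chunks(U, chunk=4)], dim=1)
```
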
Add gc.collect() + empty_cache() in detect_spikes (after clustering) and in extract() to reclaim reserved GPU memory before heavy allocations. Use in-place ops for Cf computation to avoid allocating multiple full-size temporary tensors, and delete Cf immediately after extracting max values. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
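A minimal sketch of the in-place pattern for the Cf computation (the function name is illustrative; `nm` stands in for the per-template normalizer from the surrounding code):

```python
import torch

def cf_max_inplace(B, nm):
    """Compute Cf = clamp(B, 0)^2 / nm with in-place ops so only one
    full-size temporary is allocated, then free it immediately after
    taking the max over templates (dim 0)."""
    Cf = torch.clamp(B, min=0)     # the single full-size allocation
    Cf.square_()                   # in place: no second temporary
    Cf.div_(nm.unsqueeze(-1))      # in place: no third temporary
    Cfmax, imax = Cf.max(dim=0)
    del Cf                         # release before later heavy allocations
    return Cfmax, imax
```

Compared with `torch.clamp(B, 0)**2 / nm`, which allocates three tensors the size of B, this allocates one.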
The tqdm bar previously wrapped only the outer xcent loop, which often has just 1 element (showing "1/1"). Now it tracks progress over all xcent * ycent centers, giving useful ETA estimates during the 15+ minute clustering step. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
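The fixed bar can be sketched like this (a toy example; the center grids and loop body are placeholders):

```python
from itertools import product
from tqdm import tqdm

xcent = [0.0]                    # outer grid often has just one entry
ycent = [20.0, 60.0, 100.0]      # placeholder y centers

# One bar over the flattened grid, not just the (often length-1) outer loop:
centers = list(product(xcent, ycent))
processed = 0
for xc, yc in tqdm(centers, total=len(centers), desc="First clustering"):
    # ... cluster spikes assigned to center (xc, yc) ...
    processed += 1
```

Wrapped this way, the bar advances once per center (three ticks here) instead of showing 1/1 for the outer loop.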
When a batch has only 1 spike, the interleaved split iY[1::2] is empty, causing _compute_ctc_columns to attempt to build its result from an empty list. Skip the computation when the slice is empty. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
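A sketch of the empty-slice case and the guard (the loop and call site are illustrative, not the exact code):

```python
import torch

iY = torch.tensor([42])      # a batch containing a single spike index
for j in (0, 1):
    sel = iY[j::2]           # interleaved split; iY[1::2] is empty here
    if len(sel) == 0:
        continue             # guard: skip the chunked ctc computation
    # _compute_ctc_columns(..., sel)   # hypothetical call site,
    #                                  # only reached with non-empty sel
```
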
Can you please also upload
Certainly, here is one from Windows, although note that most of my testing has been on Ubuntu 24.04.
Also happy to share the raw data if you prefer that. Like I said, there may be some motion artifacts, but I checked in the SpikeGadgets software and there's nothing very out of the ordinary. Please note that I only get the error on certain sessions; others seem to sort fine.
Are you using the same settings for all sessions? One thing I would recommend changing from the log you attached is setting

Regardless, getting the GPU computation you're referencing to proceed would not fix the issue. If you're getting ~20,000 clusters on a Neuropixels 1 probe, something else has already gone wrong.
I have played with the settings a bit to see if it works. Some runs used versions where cluster_downsampling = 20, others used 1 (as in the attached log with v4.1.1). I will start explicitly setting it to 20. It must be something about the noise structure in the recording in that case. I will investigate and let you know if I find a fix. Thank you!
With 14K+ templates and batch_size=131072, the full Cf tensor (clamp(B,0)^2/nm) requires ~7.25 GB. Instead of allocating it all at once, compute Cfmax/imax in chunks sized to fit 15% of free GPU memory. Also call empty_cache() before B allocation to reclaim reserved-but-unallocated PyTorch memory. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
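The chunking described in this commit can be sketched as follows. Both the function name and the handling of the byte budget are illustrative assumptions; in the commit the budget is 15% of free GPU memory.

```python
import torch

def chunked_cfmax(B, nm, budget_bytes):
    """Compute Cfmax/imax, the per-timepoint max over templates of
    clamp(B, 0)^2 / nm, one time-chunk at a time. Chunk width is sized
    so each fp32 Cf chunk fits the byte budget, instead of allocating
    the full N x T tensor at once."""
    N, T = B.shape
    chunk = max(1, budget_bytes // (N * 4))     # fp32 columns per chunk
    maxes, idxs = [], []
    for t0 in range(0, T, chunk):
        Cf = torch.clamp(B[:, t0:t0 + chunk], min=0)
        Cf.square_().div_(nm.unsqueeze(-1))     # in place within chunk
        m, i = Cf.max(dim=0)
        maxes.append(m)
        idxs.append(i)
    return torch.cat(maxes), torch.cat(idxs)
```

Since the max is taken over templates (dim 0) while the chunking is over time (dim 1), the chunked result is exactly the full-tensor result.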
- prepare_matching: ctc precomputation threshold now accounts for B tensor size (N * NTbuff * 4) rather than using an arbitrary 40% of free memory
- _compute_ctc_columns: chunk size now computed from free GPU memory and per-column byte cost (UtU_chunk + ctc_chunk) rather than hardcoded 128

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
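Under this sizing rule the precompute decision reduces to simple byte arithmetic. A sketch; the function name and the exact safety margin are assumptions based on this discussion:

```python
def ctc_precompute_fits(N, NTbuff, nt, free_mem, frac=0.8):
    """Precompute the full ctc tensor only if it fits alongside the
    fp32 B tensor (N x NTbuff) it must coexist with."""
    ctc_bytes = N * N * (2 * nt + 1) * 4   # fp32, N x N x (2*nt+1)
    B_bytes = N * NTbuff * 4               # fp32, N x NTbuff
    return ctc_bytes + B_bytes * 2 < free_mem * frac

# 17,345 templates (~148 GB of ctc alone) cannot be precomputed on a
# 32 GiB GPU, while a modest template count can:
small_ok = ctc_precompute_fits(1000, 131072, 61, 32 * 2**30)
big_ok = ctc_precompute_fits(17345, 131072, 61, 32 * 2**30)
```
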
Just a note to say that the updated fork now fixes the issues I had by computing template matching batchwise. Specifically:

- The ctc tensor (shape N_templates × N_templates × (2*nt+1)) scales as O(N²) and requires ~108 GB for 15K templates. When it won't fit in 40% of free GPU memory, the needed columns are computed on the fly instead.
- Changed the "First clustering" tqdm bar to track all len(xcent) * len(ycent) spatial centers instead of just showing 0/1.
- Added an elif len(iY[j::n]) > 0: guard before calling _compute_ctc_columns — when a batch has only 1 spike, slicing iY[1::2] yields an empty tensor.
- Even after the ctc fix, torch.clamp(B, min=0) allocates a full copy of B (~7.25 GB for 14826 templates × 131072 time points). Replaced with chunked computation. Tested on

Of course, this runs much more slowly whenever we have to deal with 17K+ clusters. However, the final, merged results yield a reasonable number of "good" units (~300) that are visually good in Phy. I am now testing tighter bounds on the available memory to see if I can speed up the computation slightly:

- prepare_matching threshold: now computes ctc_bytes + B_bytes * 2 < free_mem * 0.8, accounting for B (the largest competing tensor, N × NTbuff × 4 bytes)
- _compute_ctc_columns chunk size: now computed as 0.15 * free_mem / bytes_per_col, where bytes_per_col = N * (n_pcs² + (2*nt+1)) * 4 — the actual per-column cost

Will confirm once this is tested.
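The chunk-size formula quoted above amounts to the following (a sketch; ctc_chunk_size is an illustrative name):

```python
def ctc_chunk_size(N, n_pcs, nt, free_mem, frac=0.15):
    """Columns of ctc computed per chunk. Each column costs a UtU slice
    (N * n_pcs^2 floats) plus a ctc slice (N * (2*nt+1) floats), fp32."""
    bytes_per_col = N * (n_pcs ** 2 + (2 * nt + 1)) * 4
    return max(1, int(frac * free_mem / bytes_per_col))
```

For the 17,345-template run in the log (n_pcs=6, nt=61), each column costs about 11 MB, so a 20 GiB-free GPU yields chunks of a few hundred columns.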
Summary
prepare_matching tries to allocate O(N²) tensors (UtU and ctc) that can exceed GPU memory. For example, 17,345 templates requires ~30 GiB for UtU alone, which OOMs even on an RTX 5090 (32 GiB). prepare_matching now checks whether the full ctc tensor fits in available GPU memory (< 40% of free VRAM). If not, it returns only WtW and run_matching computes the needed ctc columns on-the-fly in chunks of 128, keeping peak memory around 1.4 GiB per chunk.

Context
I've tested Kilosort 4.1.7, 4.1.0 and 4.0.37 on 384-channel Neuropixels 1.0 recordings acquired using the SpikeGadgets Bennu headstage. Some sessions produce 50M+ detected spikes and 17K+ clusters after the first clustering step, probably due to intermittent motion noise artifacts. The UtU = torch.einsum('ikl, jml -> ijkm', U, U) line in prepare_matching then tries to allocate 30+ GiB, which OOMs on both an RTX 3090 (24 GiB) and an RTX 5090 (32 GiB). This happens regardless of batch_size or threshold settings.

Changes
- prepare_matching: Returns (ctc, None) when the precomputed ctc fits, or (None, WtW) when it doesn't
- _compute_ctc_columns (new): Computes ctc[:, sel_iY, :] on-the-fly in chunks of 128 templates
- run_matching: Accepts an optional WtW parameter; uses chunked computation when ctc is None
- extract: Updated to unpack the tuple and pass both to run_matching

Test plan
Full traceback for a failed run:
kilosort.run_kilosort: Kilosort version 4.1.7
kilosort.run_kilosort: Python version 3.13.5
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: System information:
kilosort.run_kilosort: Linux-6.14.0-37-generic-x86_64-with-glibc2.39 x86_64
kilosort.run_kilosort: x86_64
kilosort.run_kilosort: Using CUDA device: NVIDIA GeForce RTX 5090 31.33GB
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: Sorting [PosixPath('/mnt/s5adam/data/SocialForaging/date=2026-03-06/session=1/ephys/bat=11714/processed/11714_20260306_150617_merged.kilosort/11714_20260306_150617_merged.probe1.dat')]
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage before sorting
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 6.60 %
kilosort.run_kilosort: Mem used: 17.50 % | 21.86 GB
kilosort.run_kilosort: Mem avail: 103.31 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 2.00 %
kilosort.run_kilosort: GPU memory: 25.83 % | 8.09 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.00 % | 0.00 / 31.33 GB
kilosort.run_kilosort: Max alloc: 0.00 % | 0.00 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:
kilosort.run_kilosort: Computing preprocessing variables.
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: N samples: 209273490
kilosort.run_kilosort: N seconds: 6975.783
kilosort.run_kilosort: N batches: 1597
kilosort.run_kilosort: Preprocessing filters computed in 10.05s; total 10.18s
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after preprocessing
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 4.90 %
kilosort.run_kilosort: Mem used: 18.00 % | 22.51 GB
kilosort.run_kilosort: Mem avail: 102.66 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 30.00 %
kilosort.run_kilosort: GPU memory: 37.83 % | 11.85 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.03 % | 0.01 / 31.33 GB
kilosort.run_kilosort: Max alloc: 9.02 % | 2.83 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:
kilosort.run_kilosort: Computing drift correction.
kilosort.run_kilosort: ----------------------------------------
kilosort.spikedetect: Re-computing universal templates from data.
kilosort.spikedetect: Number of universal templates: 1532
kilosort.spikedetect: Detecting spikes...
100%|████████████████████████████| 1597/1597 [10:04<00:00, 2.64it/s]
kilosort.run_kilosort: drift computed in 665.43s; total 675.73s
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after drift correction
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 5.90 %
kilosort.run_kilosort: Mem used: 20.40 % | 25.55 GB
kilosort.run_kilosort: Mem avail: 99.62 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 0.00 %
kilosort.run_kilosort: GPU memory: 60.28 % | 18.89 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.03 % | 0.01 / 31.33 GB
kilosort.run_kilosort: Max alloc: 23.45 % | 7.35 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.io :
kilosort.io : ========================================
kilosort.io : Saving drift-corrected copy of data to: /mnt/s5adam/data/SocialForaging/date=2026-03-06/session=1/ephys/bat=11714/processed/kilosort_outdir_probe1/temp_wh.dat...
kilosort.io : Writing batch 0/1597...
kilosort.io : Writing batch 100/1597...
kilosort.io : Writing batch 200/1597...
kilosort.io : Writing batch 300/1597...
kilosort.io : Writing batch 400/1597...
kilosort.io : Writing batch 500/1597...
kilosort.io : Writing batch 600/1597...
kilosort.io : Writing batch 700/1597...
kilosort.io : Writing batch 800/1597...
kilosort.io : Writing batch 900/1597...
kilosort.io : Writing batch 1000/1597...
kilosort.io : Writing batch 1100/1597...
kilosort.io : Writing batch 1200/1597...
kilosort.io : Writing batch 1300/1597...
kilosort.io : Writing batch 1400/1597...
kilosort.io : Writing batch 1500/1597...
kilosort.io : ========================================
kilosort.io : Copying finished.
kilosort.io :
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after saving preprocessing.
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 5.20 %
kilosort.run_kilosort: Mem used: 20.50 % | 25.64 GB
kilosort.run_kilosort: Mem avail: 99.53 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 0.00 %
kilosort.run_kilosort: GPU memory: 59.77 % | 18.73 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.03 % | 0.01 / 31.33 GB
kilosort.run_kilosort: Max alloc: 9.03 % | 2.83 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: Generating drift plots ...
kilosort.run_kilosort:
kilosort.run_kilosort: Extracting spikes using templates
kilosort.run_kilosort: ----------------------------------------
kilosort.spikedetect: Re-computing universal templates from data.
kilosort.spikedetect: Number of universal templates: 1532
kilosort.spikedetect: Detecting spikes...
100%|████████████████████████████| 1597/1597 [10:03<00:00, 2.65it/s]
kilosort.run_kilosort: 53723905 spikes extracted in 625.32s; total 2400.81s
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after spike detect (univ)
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 0.00 %
kilosort.run_kilosort: Mem used: 37.40 % | 46.82 GB
kilosort.run_kilosort: Mem avail: 78.34 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 100.00 %
kilosort.run_kilosort: GPU memory: 59.84 % | 18.75 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.03 % | 0.01 / 31.33 GB
kilosort.run_kilosort: Max alloc: 23.45 % | 7.35 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:
kilosort.run_kilosort: First clustering
kilosort.run_kilosort: ----------------------------------------
100%|██████████████████████████████| 1/1 [1:25:50<00:00, 5150.49s/it]
kilosort.run_kilosort: 17345 clusters found, in 5211.94s; total 7612.78s
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after first clustering
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 4.20 %
kilosort.run_kilosort: Mem used: 39.00 % | 48.80 GB
kilosort.run_kilosort: Mem avail: 76.37 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 0.00 %
kilosort.run_kilosort: GPU memory: 88.41 % | 27.70 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.44 % | 0.14 / 31.33 GB
kilosort.run_kilosort: Max alloc: 63.36 % | 19.85 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:
kilosort.run_kilosort: Extracting spikes using cluster waveforms
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: Out of memory error, printing performance...
Traceback (most recent call last):
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/kilosort/run_kilosort.py", line 302, in _sort
st,tF, Wall0, clu0 = detect_spikes(
~~~~~~~~~~~~~^
ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
clear_cache=clear_cache, verbose=verbose_log
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/kilosort/run_kilosort.py", line 842, in detect_spikes
st, tF, ops = template_matching.extract(
~~~~~~~~~~~~~~~~~~~~~~~~~^
ops, bfile, Wall3, device=device, progress_bar=progress_bar
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/kilosort/template_matching.py", line 68, in extract
ctc = prepare_matching(ops, U)
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/kilosort/template_matching.py", line 164, in prepare_matching
UtU = torch.einsum('ikl, jml -> ijkm', U, U)
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/torch/functional.py", line 422, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 30.02 GiB. GPU 0 has a total capacity of 31.33 GiB of which 21.80 GiB is free. Process 4058 has 3.00 GiB memory in use. Including non-PyTorch memory, this process has 2.67 GiB memory in use. Of the allocated memory 409.02 MiB is allocated by PyTorch, and 1.57 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 8.80 %
kilosort.run_kilosort: Mem used: 38.90 % | 48.73 GB
kilosort.run_kilosort: Mem avail: 76.44 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 0.00 %
kilosort.run_kilosort: GPU memory: 30.44 % | 9.54 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.46 % | 0.14 / 31.33 GB
kilosort.run_kilosort: Max alloc: 1.27 % | 0.40 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: Encountered error in run_kilosort:

Parameters used in this run:
{'n_chan_bin': 384, 'fs': 30000, 'batch_size': 131072, 'nblocks': 5, 'Th_universal': 9, 'Th_learned': 8, 'tmin': 0, 'tmax': inf, 'nt': 61, 'shift': None, 'scale': None, 'batch_downsampling': 1, 'artifact_threshold': inf, 'nskip': 25, 'whitening_range': 32, 'highpass_cutoff': 300, 'binning_depth': 5, 'sig_interp': 20, 'drift_smoothing': [0.5, 0.5, 0.5], 'nt0min': None, 'dmin': None, 'dminx': 32, 'min_template_size': 10, 'template_sizes': 5, 'nearest_chans': 10, 'nearest_templates': 100, 'max_channel_distance': 32, 'max_peels': 100, 'templates_from_data': True, 'n_templates': 6, 'n_pcs': 6, 'Th_single_ch': 6, 'acg_threshold': 0.2, 'ccg_threshold': 0.25, 'cluster_neighbors': 10, 'cluster_downsampling': 20, 'max_cluster_subset': 25000, 'x_centers': None, 'cluster_init_seed': 5, 'duplicate_spike_ms': 0.25, 'position_limit': 100}