
Fix OOM in template matching for large cluster counts #1027

Open

alowet wants to merge 6 commits into MouseLand:main from alowet:fix-oom-template-matching

Conversation


@alowet alowet commented Mar 9, 2026

Summary

  • When first clustering produces a large number of templates (e.g. 17K+ from high-spike-count recordings), prepare_matching tries to allocate O(N²) tensors (UtU and ctc) that can exceed GPU memory. For example, 17,345 templates requires ~30 GiB for UtU alone, which OOMs even on an RTX 5090 (32 GiB).
  • This PR adds an adaptive fallback: prepare_matching now checks whether the full ctc tensor fits in available GPU memory (< 40% of free VRAM). If not, it returns only WtW and run_matching computes the needed ctc columns on-the-fly in chunks of 128, keeping peak memory around 1.4 GiB per chunk.
  • The original precomputed path is preserved for smaller template counts, so there is no performance impact for typical recordings.
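
The fallback decision described above boils down to a byte-count check. Here is a minimal sketch of that heuristic; the function name and the exact formula are assumptions based on this PR's description, not the actual Kilosort code:

```python
def ctc_fits_in_memory(n_templates, nt, free_bytes, frac=0.4):
    """Decide whether the full ctc tensor, assumed to be
    N x N x (2*nt+1) float32, can be precomputed, per the
    <40%-of-free-VRAM heuristic described in the PR."""
    ctc_bytes = n_templates ** 2 * (2 * nt + 1) * 4
    return ctc_bytes < frac * free_bytes
```

With nt=61 (the default in the parameter dump below), 17,345 templates put ctc well past 40% of a 32 GB card, while a few hundred templates fit easily.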

Context

I've tested Kilosort 4.1.7, 4.1.0 and 4.0.37 on 384-channel Neuropixels 1.0 recordings acquired using the SpikeGadgets Bennu headstage. Some sessions produce 50M+ detected spikes and 17K+ clusters after the first clustering step, probably due to intermittent motion noise artifacts. The UtU = torch.einsum('ikl, jml -> ijkm', U, U) line in prepare_matching then tries to allocate 30+ GiB, which OOMs on both RTX 3090 (24 GiB) and RTX 5090 (32 GiB). This happens regardless of batch_size or threshold settings.
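
A rough back-of-envelope for why this einsum blows up: the output of 'ikl, jml -> ijkm' has shape (N, N, n_pcs, n_pcs), so it scales quadratically in the template count. With n_pcs=6 (from the parameter dump below) this estimate lands in the same tens-of-GiB range as the 30.02 GiB allocation in the traceback; the exact figure depends on U's true shape, which is an assumption here:

```python
def utu_bytes(n_templates, n_pcs=6):
    # UtU from einsum('ikl, jml -> ijkm', U, U) has shape
    # (N, N, n_pcs, n_pcs) in float32 -> N^2 * n_pcs^2 * 4 bytes
    return n_templates ** 2 * n_pcs ** 2 * 4
```

At 1,000 templates this is ~0.14 GiB; at 17,345 it is tens of GiB, which no consumer GPU can hold.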

Changes

  • prepare_matching: Returns (ctc, None) when precomputed ctc fits, or (None, WtW) when it doesn't
  • _compute_ctc_columns (new): Computes ctc[:, sel_iY, :] on-the-fly in chunks of 128 templates
  • run_matching: Accepts optional WtW parameter; uses chunked computation when ctc is None
  • extract: Updated to unpack the tuple and pass both to run_matching
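
The chunked column computation can be sketched as follows. NumPy's einsum stands in for torch.einsum here for portability, and the tensor shapes (U as (N, n_pcs, n_chan), WtW as (n_pcs, n_pcs, T)) are assumptions based on the einsum signature in the traceback, not the actual Kilosort layout:

```python
import numpy as np

def compute_ctc_columns(U, WtW, sel, chunk=128):
    """Sketch of the chunked fallback: compute ctc[:, sel, :] without
    ever materializing the full (N, N, T) tensor."""
    N, T = U.shape[0], WtW.shape[-1]
    out = np.empty((N, len(sel), T), dtype=U.dtype)
    for s in range(0, len(sel), chunk):
        idx = sel[s:s + chunk]
        # pairwise template overlaps for the selected columns only:
        # shape (N, len(idx), n_pcs, n_pcs) instead of (N, N, ...)
        UtU = np.einsum('ikl, jml -> ijkm', U, U[idx])
        out[:, s:s + len(idx)] = np.einsum('ijkm, kmt -> ijt', UtU, WtW)
    return out
```

Each chunk's peak footprint is proportional to N * chunk rather than N², which is what caps memory at roughly a fixed amount per chunk.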

Test plan

  • Verify no regression on standard recordings (small template count → precomputed path)
  • Verify high-spike-count recordings (17K+ templates) no longer OOM
  • Verify spike sorting results are numerically identical for both paths (chunked computation produces the same ctc values)
  • pytest --gpu --runslow shows 30 tests passed

Full traceback for a failed run:

kilosort.run_kilosort: Kilosort version 4.1.7
kilosort.run_kilosort: Python version 3.13.5
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: System information:
kilosort.run_kilosort: Linux-6.14.0-37-generic-x86_64-with-glibc2.39 x86_64
kilosort.run_kilosort: x86_64
kilosort.run_kilosort: Using CUDA device: NVIDIA GeForce RTX 5090 31.33GB
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: Sorting [PosixPath('/mnt/s5adam/data/SocialForaging/date=2026-03-06/session=1/ephys/bat=11714/processed/11714_20260306_150617_merged.kilosort/11714_20260306_150617_merged.probe1.dat')]
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage before sorting
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 6.60 %
kilosort.run_kilosort: Mem used: 17.50 % | 21.86 GB
kilosort.run_kilosort: Mem avail: 103.31 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 2.00 %
kilosort.run_kilosort: GPU memory: 25.83 % | 8.09 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.00 % | 0.00 / 31.33 GB
kilosort.run_kilosort: Max alloc: 0.00 % | 0.00 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:
kilosort.run_kilosort: Computing preprocessing variables.
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: N samples: 209273490
kilosort.run_kilosort: N seconds: 6975.783
kilosort.run_kilosort: N batches: 1597
kilosort.run_kilosort: Preprocessing filters computed in 10.05s; total 10.18s
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after preprocessing
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 4.90 %
kilosort.run_kilosort: Mem used: 18.00 % | 22.51 GB
kilosort.run_kilosort: Mem avail: 102.66 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 30.00 %
kilosort.run_kilosort: GPU memory: 37.83 % | 11.85 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.03 % | 0.01 / 31.33 GB
kilosort.run_kilosort: Max alloc: 9.02 % | 2.83 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:
kilosort.run_kilosort: Computing drift correction.
kilosort.run_kilosort: ----------------------------------------
kilosort.spikedetect: Re-computing universal templates from data.
kilosort.spikedetect: Number of universal templates: 1532
kilosort.spikedetect: Detecting spikes...
100%|████████████████████████████| 1597/1597 [10:04<00:00, 2.64it/s]
kilosort.run_kilosort: drift computed in 665.43s; total 675.73s
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after drift correction
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 5.90 %
kilosort.run_kilosort: Mem used: 20.40 % | 25.55 GB
kilosort.run_kilosort: Mem avail: 99.62 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 0.00 %
kilosort.run_kilosort: GPU memory: 60.28 % | 18.89 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.03 % | 0.01 / 31.33 GB
kilosort.run_kilosort: Max alloc: 23.45 % | 7.35 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.io :
kilosort.io : ========================================
kilosort.io : Saving drift-corrected copy of data to: /mnt/s5adam/data/SocialForaging/date=2026-03-06/session=1/ephys/bat=11714/processed/kilosort_outdir_probe1/temp_wh.dat...
kilosort.io : Writing batch 0/1597...
kilosort.io : Writing batch 100/1597...
kilosort.io : Writing batch 200/1597...
kilosort.io : Writing batch 300/1597...
kilosort.io : Writing batch 400/1597...
kilosort.io : Writing batch 500/1597...
kilosort.io : Writing batch 600/1597...
kilosort.io : Writing batch 700/1597...
kilosort.io : Writing batch 800/1597...
kilosort.io : Writing batch 900/1597...
kilosort.io : Writing batch 1000/1597...
kilosort.io : Writing batch 1100/1597...
kilosort.io : Writing batch 1200/1597...
kilosort.io : Writing batch 1300/1597...
kilosort.io : Writing batch 1400/1597...
kilosort.io : Writing batch 1500/1597...
kilosort.io : ========================================
kilosort.io : Copying finished.
kilosort.io :
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after saving preprocessing.
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 5.20 %
kilosort.run_kilosort: Mem used: 20.50 % | 25.64 GB
kilosort.run_kilosort: Mem avail: 99.53 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 0.00 %
kilosort.run_kilosort: GPU memory: 59.77 % | 18.73 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.03 % | 0.01 / 31.33 GB
kilosort.run_kilosort: Max alloc: 9.03 % | 2.83 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: Generating drift plots ...
kilosort.run_kilosort:
kilosort.run_kilosort: Extracting spikes using templates
kilosort.run_kilosort: ----------------------------------------
kilosort.spikedetect: Re-computing universal templates from data.
kilosort.spikedetect: Number of universal templates: 1532
kilosort.spikedetect: Detecting spikes...
100%|████████████████████████████| 1597/1597 [10:03<00:00, 2.65it/s]
kilosort.run_kilosort: 53723905 spikes extracted in 625.32s; total 2400.81s
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after spike detect (univ)
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 0.00 %
kilosort.run_kilosort: Mem used: 37.40 % | 46.82 GB
kilosort.run_kilosort: Mem avail: 78.34 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 100.00 %
kilosort.run_kilosort: GPU memory: 59.84 % | 18.75 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.03 % | 0.01 / 31.33 GB
kilosort.run_kilosort: Max alloc: 23.45 % | 7.35 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:
kilosort.run_kilosort: First clustering
kilosort.run_kilosort: ----------------------------------------
100%|██████████████████████████████| 1/1 [1:25:50<00:00, 5150.49s/it]
kilosort.run_kilosort: 17345 clusters found, in 5211.94s; total 7612.78s
kilosort.run_kilosort:
kilosort.run_kilosort: Resource usage after first clustering
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 4.20 %
kilosort.run_kilosort: Mem used: 39.00 % | 48.80 GB
kilosort.run_kilosort: Mem avail: 76.37 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 0.00 %
kilosort.run_kilosort: GPU memory: 88.41 % | 27.70 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.44 % | 0.14 / 31.33 GB
kilosort.run_kilosort: Max alloc: 63.36 % | 19.85 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort:
kilosort.run_kilosort: Extracting spikes using cluster waveforms
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: Out of memory error, printing performance...
Traceback (most recent call last):
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/kilosort/run_kilosort.py", line 302, in _sort
st,tF, Wall0, clu0 = detect_spikes(
~~~~~~~~~~~~~^
ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
clear_cache=clear_cache, verbose=verbose_log
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/kilosort/run_kilosort.py", line 842, in detect_spikes
st, tF, ops = template_matching.extract(
~~~~~~~~~~~~~~~~~~~~~~~~~^
ops, bfile, Wall3, device=device, progress_bar=progress_bar
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/kilosort/template_matching.py", line 68, in extract
ctc = prepare_matching(ops, U)
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/kilosort/template_matching.py", line 164, in prepare_matching
UtU = torch.einsum('ikl, jml -> ijkm', U, U)
File "/home/alowet/miniconda3/envs/kilosort/lib/python3.13/site-packages/torch/functional.py", line 422, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 30.02 GiB. GPU 0 has a total capacity of 31.33 GiB of which 21.80 GiB is free. Process 4058 has 3.00 GiB memory in use. Including non-PyTorch memory, this process has 2.67 GiB memory in use. Of the allocated memory 409.02 MiB is allocated by PyTorch, and 1.57 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: CPU usage: 8.80 %
kilosort.run_kilosort: Mem used: 38.90 % | 48.73 GB
kilosort.run_kilosort: Mem avail: 76.44 / 125.17 GB
kilosort.run_kilosort: ------------------------------------------------------
kilosort.run_kilosort: GPU usage: 0.00 %
kilosort.run_kilosort: GPU memory: 30.44 % | 9.54 / 31.33 GB
kilosort.run_kilosort: Allocated: 0.46 % | 0.14 / 31.33 GB
kilosort.run_kilosort: Max alloc: 1.27 % | 0.40 / 31.33 GB
kilosort.run_kilosort: ********************************************************
kilosort.run_kilosort: Encountered error in run_kilosort:

Parameters used in this run:

{'n_chan_bin': 384, 'fs': 30000, 'batch_size': 131072, 'nblocks': 5, 'Th_universal': 9, 'Th_learned': 8, 'tmin': 0, 'tmax': inf, 'nt': 61, 'shift': None, 'scale': None, 'batch_downsampling': 1, 'artifact_threshold': inf, 'nskip': 25, 'whitening_range': 32, 'highpass_cutoff': 300, 'binning_depth': 5, 'sig_interp': 20, 'drift_smoothing': [0.5, 0.5, 0.5], 'nt0min': None, 'dmin': None, 'dminx': 32, 'min_template_size': 10, 'template_sizes': 5, 'nearest_chans': 10, 'nearest_templates': 100, 'max_channel_distance': 32, 'max_peels': 100, 'templates_from_data': True, 'n_templates': 6, 'n_pcs': 6, 'Th_single_ch': 6, 'acg_threshold': 0.2, 'ccg_threshold': 0.25, 'cluster_neighbors': 10, 'cluster_downsampling': 20, 'max_cluster_subset': 25000, 'x_centers': None, 'cluster_init_seed': 5, 'duplicate_spike_ms': 0.25, 'position_limit': 100}

When clustering produces many templates (e.g. 17K+), prepare_matching
computes UtU and ctc tensors that scale as O(N^2) and can exceed GPU
memory (30+ GiB for 17K templates). This adds an adaptive fallback
that computes ctc columns on-the-fly in chunks of 128 when the full
tensor would not fit, keeping peak memory around 1.4 GiB per chunk.
The original precomputed path is preserved for smaller template counts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@alowet alowet closed this Mar 9, 2026
@alowet alowet deleted the fix-oom-template-matching branch March 9, 2026 21:40
Add gc.collect() + empty_cache() in detect_spikes (after clustering)
and in extract() to reclaim reserved GPU memory before heavy
allocations. Use in-place ops for Cf computation to avoid allocating
multiple full-size temporary tensors, and delete Cf immediately after
extracting max values.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
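
The two ideas in this commit, sketched below. The cleanup helper is standard PyTorch practice; the in-place Cf variant is a sketch under assumed shapes (B as (N, T), nm as (N,)), not the literal Kilosort code:

```python
import gc
import torch

def reclaim_gpu_memory():
    # drop dead Python references, then release cached CUDA blocks
    # back to the driver before the next large allocation
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def cf_inplace(B, nm):
    """In-place version of Cf = clamp(B, 0)**2 / nm: one temporary
    (the clamp result) instead of several full-size intermediates."""
    Cf = torch.clamp(B, min=0)   # the only new full-size allocation
    Cf.pow_(2)                   # squared in place
    Cf.div_(nm[:, None])         # normalized in place
    return Cf
```

empty_cache() does not free tensors that are still referenced, which is why the gc.collect() (and the explicit `del Cf` mentioned above) matter.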
@alowet alowet reopened this Mar 9, 2026
alowet and others added 2 commits March 9, 2026 17:31
The tqdm bar previously wrapped only the outer xcent loop, which
often has just 1 element (showing "1/1"). Now it tracks progress
over all xcent * ycent centers, giving useful ETA estimates during
the 15+ minute clustering step.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
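
The progress-bar change amounts to sizing the bar by the full cross product of centers and updating it in the inner loop. A minimal sketch, where `cluster_one` is a hypothetical stand-in for the per-center clustering work:

```python
from tqdm import tqdm

def cluster_all_centers(xcent, ycent, cluster_one):
    """Track progress over every (x, y) spatial center instead of
    only the outer loop, which often has a single element."""
    results = []
    with tqdm(total=len(xcent) * len(ycent)) as pbar:
        for x in xcent:
            for y in ycent:
                results.append(cluster_one(x, y))
                pbar.update(1)
    return results
```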
When a batch has only 1 spike, the interleaved split iY[1::2] is
empty, causing _compute_ctc_columns to be called with an empty
selection and raise an IndexError. Skip the computation when the
slice is empty.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
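
The edge case is easy to reproduce with plain slicing: a one-element sequence has an empty odd-indexed interleave. A minimal sketch of the guard (the helper name is illustrative):

```python
def interleaved_groups(iY, n=2):
    """Yield the non-empty interleaved slices iY[j::n]. With a
    single spike, iY[1::2] is empty and must be skipped so the
    downstream column computation never sees an empty selection."""
    for j in range(n):
        sel = iY[j::n]
        if len(sel) == 0:
            continue  # the guard added in this commit
        yield sel
```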
@jacobpennington
Collaborator

Can you please also upload kilosort4.log for one of the times you encountered this issue? Something like this approach might be useful for a much larger probe, but you should have nowhere near that many templates on a Neuropixels 1 probe.

@alowet
Author

alowet commented Mar 10, 2026

Certainly, here is one from Windows, although note that most of my testing has been on Ubuntu 24.04.
kilosort4windows.log

@alowet
Author

alowet commented Mar 10, 2026

Also happy to share the raw data if you prefer that. Like I said, there may be some motion artifacts, but I checked in the SpikeGadgets software and there's nothing very out of the ordinary. Please note that I only get the error on certain sessions; others seem to sort fine.

@jacobpennington
Collaborator

jacobpennington commented Mar 10, 2026

Are you using the same settings for all sessions? One thing I would recommend changing from the log you attached is setting cluster_downsampling = 20. The default was changed to cluster_downsampling = 1 for a few versions, but we reverted it back because we saw issues with oversplitting and spurious clusters for some datasets. That may be what's going on here.

Regardless, getting the GPU computation you're referencing to proceed would not fix the issue. If you're getting ~20,000 clusters on a Neuropixels 1 probe, something else has already gone wrong.

@alowet
Author

alowet commented Mar 10, 2026

I have experimented with the settings to see whether it helps. Some runs used versions where cluster_downsampling = 20; others used 1 (as in the attached log from v4.1.1). I will start explicitly setting it to 20.

It must be something about the noise structure in the recording in that case. I will investigate and let you know if I find a fix. Thank you!

alowet and others added 2 commits March 9, 2026 20:56
With 14K+ templates and batch_size=131072, the full Cf tensor
(clamp(B,0)^2/nm) requires ~7.25 GB. Instead of allocating it all
at once, compute Cfmax/imax in chunks sized to fit 15% of free GPU
memory. Also call empty_cache() before B allocation to reclaim
reserved-but-unallocated PyTorch memory.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
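
The chunked max described in this commit can be sketched as a running reduction over template-dimension slices. Shapes are assumptions (B as (N, T), nm as (N,)); in the real code `chunk` would be sized from free GPU memory rather than passed in:

```python
import torch

def chunked_cf_max(B, nm, chunk):
    """Compute Cf.max(dim=0) where Cf = clamp(B, 0)**2 / nm, without
    materializing the full (N, T) Cf tensor."""
    Cfmax = imax = None
    for s in range(0, B.shape[0], chunk):
        # clamp allocates only a chunk-sized copy; pow_/div_ reuse it
        Cf = torch.clamp(B[s:s + chunk], min=0).pow_(2).div_(nm[s:s + chunk, None])
        vals, idx = Cf.max(dim=0)
        idx = idx + s  # shift chunk-local indices to global template ids
        if Cfmax is None:
            Cfmax, imax = vals, idx
        else:
            better = vals > Cfmax
            Cfmax = torch.where(better, vals, Cfmax)
            imax = torch.where(better, idx, imax)
    return Cfmax, imax
```

Peak memory is then one chunk of Cf plus two (T,)-sized running buffers, instead of the full N x T tensor.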
- prepare_matching: ctc precomputation threshold now accounts for the B
  tensor size (N * NTbuff * 4) rather than using an arbitrary 40% of
  free memory
- _compute_ctc_columns: chunk size now computed from free GPU memory
  and the per-column byte cost (UtU_chunk + ctc_chunk) rather than a
  hardcoded 128

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
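
The two refined heuristics from this commit, as arithmetic sketches. Function names and the exact variable layout are assumptions reconstructed from the commit message, not the actual code:

```python
def should_precompute_ctc(n_templates, nt, ntbuff, free_bytes):
    """Refined threshold: the full ctc (N x N x (2*nt+1) float32) must
    coexist with B (N x NTbuff float32), with 2x B as headroom for
    conv1d/einsum intermediates, all within 80% of free memory."""
    ctc_bytes = n_templates ** 2 * (2 * nt + 1) * 4
    b_bytes = n_templates * ntbuff * 4
    return ctc_bytes + 2 * b_bytes < 0.8 * free_bytes

def ctc_chunk_size(n_templates, nt, n_pcs, free_bytes):
    # per-column cost of UtU_chunk + ctc_chunk, capped at 15% of free
    bytes_per_col = n_templates * (n_pcs ** 2 + (2 * nt + 1)) * 4
    return max(1, int(0.15 * free_bytes / bytes_per_col))
```

Both formulas adapt to the template count and to whatever memory is actually free, instead of the fixed 40% threshold and fixed chunk of 128 in the first version of the PR.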
@alowet
Author

alowet commented Mar 10, 2026

Just a note to say that the updated fork now fixes the issues I had by computing template matching batchwise. Specifically:

  1. Chunked ctc computation in prepare_matching (ea014f4)

  The ctc tensor (shape N_templates × N_templates × (2*nt+1)) scales as O(N²) and requires ~108 GB for 15K templates. When it won't fit in 40% of free GPU memory, we skip precomputing it and instead compute columns on-the-fly during peeling via _compute_ctc_columns, processing templates in chunks of 128.

  2. GPU memory cleanup before template matching (d5b2456)

  • Added gc.collect() + torch.cuda.empty_cache() in extract() before heavy allocations to reclaim memory from prior pipeline stages
  • Made Cf computation use in-place ops (pow_, div_) and del Cf after extracting max values

  3. Informative clustering progress bar (ba260a8)

  Changed the "First clustering" tqdm bar to track all len(xcent) * len(ycent) spatial centers instead of just showing 0/1.

  4. Fix IndexError for single-spike batches (38d812e)

  Added an elif len(iY[j::n]) > 0: guard before calling _compute_ctc_columns: when a batch has only one spike, slicing iY[1::2] yields an empty tensor, which caused an IndexError.

  5. Chunked Cf computation to avoid second OOM (7c865a5)

  Even after the ctc fix, torch.clamp(B, min=0) allocates a full copy of B (~7.25 GB for 14826 templates × 131072 time points). Replaced with chunked computation: Cfmax/imax are computed over template-dimension chunks sized to 15% of free GPU memory, never materializing the full Cf tensor. Also added torch.cuda.empty_cache() before B allocation to reclaim reserved-but-unallocated PyTorch memory.

Tested on

  • 384-channel Neuropixels 1.0 recordings producing 53.7M spikes and 17K+ clusters
  • NVIDIA RTX 5090 (31.33 GB) — previously OOM'd, now completes successfully

Of course, this runs much more slowly whenever we have to deal with 17K+ clusters. However, the final, merged results yield a reasonable number of "good" units (~300) that are visually good in Phy.

I am now testing tighter bounds on the available memory to see if I can speed up the computation slightly:

prepare_matching threshold: now computes ctc_bytes + B_bytes * 2 < free_mem * 0.8, accounting for B (the largest competing tensor, N × NTbuff × 4) needing to coexist with ctc, with 2× B as headroom for intermediates during B computation (conv1d + einsum).

_compute_ctc_columns chunk size: now computed as 0.15 * free_mem / bytes_per_col, where bytes_per_col = N * (n_pcs² + (2*nt+1)) * 4 is the actual per-column cost of UtU_chunk and ctc_chunk. This adapts to both the number of templates and available memory.

Will confirm once this is tested.

