Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,14 +442,16 @@ def run(ops, st, tF, mode='template', device=torch.device('cuda'),
Nfilt = None
nearby_chans_empty = 0
nmax = 0
prog = tqdm(np.arange(len(xcent)), miniters=20 if progress_bar else None,
n_total = len(xcent) * len(ycent)
prog = tqdm(total=n_total, miniters=20 if progress_bar else None,
mininterval=10 if progress_bar else None)
t = 0
v = False

try:
for jj in prog:
for jj in np.arange(len(xcent)):
for kk in np.arange(len(ycent)):
prog.update(1)
# Get data for all templates that were closest to this x,y center.
ii = kk + jj*ycent.size
if ii not in nearest_center:
Expand Down Expand Up @@ -541,6 +543,8 @@ def run(ops, st, tF, mode='template', device=torch.device('cuda'),
logger.debug('iclust not yet assigned')
pass
raise
finally:
prog.close()

if nearby_chans_empty == total_centers:
raise ValueError(
Expand Down
6 changes: 6 additions & 0 deletions kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,12 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,

log_thread_count(logger)

# Free large intermediates before template matching
del st0, tF
import gc
gc.collect()
torch.cuda.empty_cache()

tic = time.time()
logger.info(' ')
logger.info('Extracting spikes using cluster waveforms')
Expand Down
122 changes: 94 additions & 28 deletions kilosort/template_matching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import gc
import logging

import numpy as np
import torch
import torch
from torch.nn.functional import conv1d, max_pool2d, max_pool1d
from tqdm import tqdm

Expand Down Expand Up @@ -64,8 +65,13 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
ops['iU'] = iU
nt = ops['nt']

tiwave = torch.arange(-(nt//2), nt//2+1, device=device)
ctc = prepare_matching(ops, U)
tiwave = torch.arange(-(nt//2), nt//2+1, device=device)

# Free memory from prior pipeline stages before heavy allocations
gc.collect()
torch.cuda.empty_cache()

ctc, WtW = prepare_matching(ops, U)
st = np.zeros((10**6, 3), 'float64')
tF = torch.zeros((10**6, nC , ops['settings']['n_pcs']))
k = 0
Expand All @@ -81,7 +87,7 @@ def extract(ops, bfile, U, device=torch.device('cuda'), progress_bar=None):
log_performance(logger, 'debug', f'Batch {ibatch}')

X = bfile.padded_batch_to_torch(ibatch, ops)
stt, amps, th_amps, Xres = run_matching(ops, X, U, ctc, device=device)
stt, amps, th_amps, Xres = run_matching(ops, X, U, ctc, WtW, device=device)
xfeat = Xres[iCC[:, iU[stt[:,1:2]]],stt[:,:1] + tiwave] @ ops['wPCA'].T
xfeat += amps * Ucc[:,stt[:,1]]

Expand Down Expand Up @@ -155,33 +161,77 @@ def postprocess_templates(Wall, ops, clu, st, tF, device=torch.device('cuda')):
def prepare_matching(ops, U):
    """Precompute template cross-correlation terms for matching pursuit.

    Builds ``WtW``, the temporal cross-correlation of every pair of PC
    waveforms, and — when free GPU memory allows — the full pairwise
    template cross-correlation tensor ``ctc`` of shape (N, N, 2*nt+1).

    Parameters
    ----------
    ops : dict
        Must contain 'nt' (samples per waveform), 'wPCA' (PC waveform
        tensor), and 'NTbuff' (padded batch length).
    U : torch.Tensor
        Template tensor; axis 0 indexes the N templates.

    Returns
    -------
    tuple
        ``(ctc, None)`` when the full ctc tensor fits in free GPU memory,
        otherwise ``(None, WtW)`` so run_matching can compute the needed
        ctc columns on the fly (see ``_compute_ctc_columns``).
    """
    nt = ops['nt']
    W = ops['wPCA'].contiguous()
    # Cross-correlate every PC waveform with every other; flip so the
    # result is indexed by lag in [-nt, nt].
    WtW = conv1d(W.reshape(-1, 1, nt), W.reshape(-1, 1, nt), padding=nt)
    WtW = torch.flip(WtW, [2,])

    N = U.shape[0]
    ctc_bytes = N * N * (2*nt+1) * 4  # float32
    try:
        free_mem = torch.cuda.mem_get_info(U.device)[0]
    except Exception:
        # CPU device or failed CUDA query: force the chunked path.
        free_mem = 0

    # Estimate memory needed for other tensors in run_matching:
    # B: N * NTbuff * 4 bytes, Xres: n_chan * NTbuff * 4 (small), plus headroom.
    NTbuff = ops['NTbuff']
    B_bytes = N * NTbuff * 4
    # ctc fits if it plus B plus a safety margin fit in free memory.
    if ctc_bytes + B_bytes * 2 < free_mem * 0.8:
        UtU = torch.einsum('ikl, jml -> ijkm', U, U)
        ctc = torch.einsum('ijkm, kml -> ijl', UtU, WtW)
        return ctc, None
    else:
        logger.info(
            f'Too many templates ({N}) for precomputed ctc '
            f'({ctc_bytes/1e9:.1f} GB), using chunked on-the-fly computation'
        )
        return None, WtW


def _compute_ctc_columns(U, WtW, sel_iY, device=None):
"""Compute ctc[:, sel_iY, :] on-the-fly in chunks to avoid O(N^2) memory."""
N = U.shape[0]
n_pcs = U.shape[1]
nt2 = WtW.shape[2] # 2*nt+1
if device is None:
device = U.device

# UtU_chunk: N * chunk * n_pcs * n_pcs * 4 bytes
# ctc_chunk: N * chunk * nt2 * 4 bytes
free_mem = torch.cuda.mem_get_info(device)[0]
bytes_per_col = N * (n_pcs * n_pcs + nt2) * 4
chunk_size = max(1, int(0.15 * free_mem / bytes_per_col))

ctc_chunks = []
for c_start in range(0, len(sel_iY), chunk_size):
c_end = min(c_start + chunk_size, len(sel_iY))
U_chunk = U[sel_iY[c_start:c_end]]
UtU_chunk = torch.einsum('ikl, jml -> ijkm', U, U_chunk)
ctc_chunk = torch.einsum('ijkm, kml -> ijl', UtU_chunk, WtW)
ctc_chunks.append(ctc_chunk)
del UtU_chunk
return torch.cat(ctc_chunks, dim=1) if len(ctc_chunks) > 1 else ctc_chunks[0]


def run_matching(ops, X, U, ctc, WtW=None, device=torch.device('cuda')):
Th = ops['Th_learned']
nt = ops['nt']
max_peels = ops['max_peels']
W = ops['wPCA'].contiguous()

nm = (U**2).sum(-1).sum(-1)
#mu = nm**.5
#mu = nm**.5
#U2 = U / mu.unsqueeze(-1).unsqueeze(-1)

# Reclaim reserved-but-unallocated memory before large B allocation
torch.cuda.empty_cache()

B = conv1d(X.unsqueeze(1), W.unsqueeze(1), padding=nt//2)
B = torch.einsum('ijk, kjl -> il', U, B)

trange = torch.arange(-nt, nt+1, device=device)
tiwave = torch.arange(-(nt//2), nt//2+1, device=device)
trange = torch.arange(-nt, nt+1, device=device)
tiwave = torch.arange(-(nt//2), nt//2+1, device=device)

st = torch.zeros((100000,2), dtype = torch.int64, device = device)
amps = torch.zeros((100000,1), dtype = torch.float, device = device)
Expand All @@ -192,17 +242,28 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
lam = 20

for t in range(max_peels):
# Cf = 2 * B - nm.unsqueeze(-1)
# Cf is shape (n_units, n_times)
Cf = torch.relu(B)**2 /nm.unsqueeze(-1)
# Cf = clamp(B,0)^2 / nm, shape (n_units, n_times)
# Compute Cfmax and imax in chunks to avoid allocating full Cf tensor
#a = 1 + lam
#b = torch.relu(B) + lam * mu.unsqueeze(-1)
#Cf = b**2 / a - lam * mu.unsqueeze(-1)**2

Cf[:, :nt] = 0
Cf[:, -nt:] = 0

Cfmax, imax = torch.max(Cf, 0)
N_templates = B.shape[0]
NT = B.shape[1]
cf_chunk_size = max(1, min(N_templates, int(0.15 * torch.cuda.mem_get_info(device)[0] / (NT * 4))))
Cfmax = torch.full((NT,), -1.0, device=device)
imax = torch.zeros(NT, dtype=torch.long, device=device)
for ci in range(0, N_templates, cf_chunk_size):
ce = min(ci + cf_chunk_size, N_templates)
Cf_chunk = torch.clamp(B[ci:ce], min=0)
Cf_chunk.pow_(2).div_(nm[ci:ce].unsqueeze(-1))
Cf_chunk[:, :nt] = 0
Cf_chunk[:, -nt:] = 0
chunk_max, chunk_imax = torch.max(Cf_chunk, 0)
better = chunk_max > Cfmax
imax[better] = chunk_imax[better] + ci
Cfmax[better] = chunk_max[better]
del Cf_chunk, chunk_max, chunk_imax
Cmax = max_pool1d(Cfmax.unsqueeze(0).unsqueeze(0), (2*nt+1), stride=1, padding=(nt))

#print(Cfmax.shape)
Expand All @@ -211,7 +272,7 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
cnd2 = torch.abs(Cmax[0,0] - Cfmax) < 1e-9
xs = torch.nonzero(cnd1 * cnd2)


if len(xs)==0:
#print('iter %d'%t)
break
Expand All @@ -230,12 +291,17 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):

k+= nsp

#amp = B[iY,iX]
#amp = B[iY,iX]

n = 2
for j in range(n):
Xres[:, iX[j::n] + tiwave] -= amp[j::n] * torch.einsum('ijk, jl -> kil', U[iY[j::n,0]], W)
B[ :, iX[j::n] + trange] -= amp[j::n] * ctc[:,iY[j::n,0],:]
if ctc is not None:
B[ :, iX[j::n] + trange] -= amp[j::n] * ctc[:,iY[j::n,0],:]
elif len(iY[j::n]) > 0:
ctc_sel = _compute_ctc_columns(U, WtW, iY[j::n, 0], device=device)
B[:, iX[j::n] + trange] -= amp[j::n] * ctc_sel
del ctc_sel

st = st[:k]
amps = amps[:k]
Expand Down