Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions cellpose/contrib/cluster_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from pathlib import Path
import subprocess
import pickle

import zarr
from tifffile import imread
from pooch import retrieve
from cellpose.contrib.distributed_segmentation import numpy_array_to_zarr

from cellpose.contrib.distributed_segmentation import distributed_eval
from cellpose.contrib.distributed_segmentation import SlurmCluster, janeliaLSFCluster

def main():
    """Distributed 3D Cellpose segmentation of a test volume on an HPC cluster.

    Downloads a sample DAPI-stained gastruloid stack, converts it to a
    chunked zarr array, auto-detects the cluster scheduler (Slurm or LSF),
    runs ``distributed_eval`` across GPU workers, and writes the label
    image plus per-object bounding boxes to disk.
    """
    ## PARAMETERS
    # Compute node-accessible directory for input zarr dataset and outputs
    output_dir = Path() / 'outputs'
    # Make sure the output directory exists before anything tries to write into it.
    output_dir.mkdir(parents=True, exist_ok=True)
    input_zarr_path = output_dir / 'input.zarr'
    output_zarr_path = output_dir / 'segmentation.zarr'
    output_bbox_pkl = output_dir / 'bboxes.pkl'

    # Cluster parameters (example: https://docs.mpcdf.mpg.de/doc/computing/viper-gpu-user-guide.html)
    cluster_kwargs = {
        'job_cpu': 2,               # number of CPUs per GPU worker
        'ncpus': 1,                 # threads requested per GPU worker
        'min_workers': 1,           # min number of workers based on expected workload
        'max_workers': 16,          # max number of workers based on expected workload
        'walltime': '1:00:00',      # available runtime for each GPU worker for cluster scheduler (Slurm, LSF)
        'queue': 'apu',             # queue/ partition name for single GPU worker *
        'interface': 'ib0',         # interface name for compute-node communication *
        'local_directory': '/tmp',  # compute node local temporary directory *
        'job_extra_directives': [   # extra directives for scheduler (here: Slurm) *
            '--constraint apu',
            '--gres gpu:1',
        ],
    }
    # * Ask your cluster support staff for assistance

    # Cellpose parameters
    model_kwargs = {'gpu': True}
    eval_kwargs = {
        'z_axis': 0,
        'do_3D': True,
    }

    # Optional: Crop data to reduce runtime for this test case
    crop = (slice(0, 221), slice(1024, 2048), slice(1024, 2048))


    ## DATA PREPARATION
    # here: DAPI-stained human gastruloid by Zhiyuan Yu (https://zenodo.org/records/17590053)
    if not input_zarr_path.exists():
        print('Download test data')
        # pooch verifies the download against the known sha256 hash and caches it.
        fname = retrieve(
            url="https://zenodo.org/records/17590053/files/2d_gastruloid.tif?download=1",
            known_hash="8ac2d944882268fbaebdfae5f7c18e4d20fdab024db2f9f02f4f45134b936872",
            path=Path.home() / '.cellpose' / 'data',
            progressbar=True,
        )
        data_numpy = imread(fname)[crop]

        print(f'Convert to {data_numpy.shape} zarr array')
        data_zarr = numpy_array_to_zarr(input_zarr_path, data_numpy, chunks=(256, 256, 256))
        print(f'Input stored in {input_zarr_path}')
        # Free the in-memory copy; segmentation reads from the zarr store.
        del data_numpy
    else:
        print(f'Read input data from {input_zarr_path}')
        data_zarr = zarr.open(input_zarr_path)

    ## EVALUATION
    # Guess cluster type by checking for cluster submission commands on PATH.
    if subprocess.getstatusoutput('sbatch -h')[0] == 0:
        print('Slurm sbatch command detected -> use SlurmCluster')
        cluster = SlurmCluster(**cluster_kwargs)
    elif subprocess.getstatusoutput('bsub -h')[0] == 0:
        print('LSF bsub command detected -> use janeliaLSFCluster')
        cluster = janeliaLSFCluster(**cluster_kwargs)
    else:
        cluster = None
    ## Note in case you want to test without a cluster scheduler use:
    #from cellpose.contrib.distributed_segmentation import myLocalCluster
    #cluster = myLocalCluster(**{
    #    'n_workers': 1, # if you only have 1 gpu, then 1 worker is the right choice
    #    'ncpus': 8,
    #    'memory_limit':'64GB',
    #    'threads_per_worker':1,
    #})

    if cluster is None:
        raise Exception(
            "Neither SLURM nor LSF cluster detected. "
            "Currently, this script only supports SLURM or LSF cluster scheduler. "
            "You have two options:"
            "\n * Either use `distributed_eval` without the `cluster` but with the `cluster_kwargs` argument to start a local cluster on your machine"
            "\n * or raise a feature request at https://github.com/MouseLand/cellpose/issues."
        )

    # Start computation
    segments, boxes = distributed_eval(
        input_zarr=data_zarr,
        blocksize=(256, 256, 256),
        write_path=str(output_zarr_path),
        model_kwargs=model_kwargs,
        eval_kwargs=eval_kwargs,
        cluster=cluster,
    )

    # Save bounding boxes on disk
    with open(output_bbox_pkl, 'wb') as f:
        pickle.dump(boxes, f)

    print(f'Segmentation saved in {str(output_zarr_path)}')
    print(f'Object bounding boxes saved in {str(output_bbox_pkl)}')

# Standard script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
148 changes: 132 additions & 16 deletions cellpose/contrib/distributed_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,111 @@ def __exit__(self, exc_type, exc_value, traceback):
self.client.close()
super().__exit__(exc_type, exc_value, traceback)

class SlurmCluster(dask_jobqueue.SLURMCluster):
    """
    This is a thin wrapper extending dask_jobqueue.SLURMCluster,
    which in turn extends dask.distributed.SpecCluster. This wrapper
    sets configs before the cluster or workers are initialized. This is
    an adaptive cluster and will scale the number of workers, between user
    specified limits, based on the number of pending tasks.

    For a full list of arguments see
    https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.SLURMCluster.html

    Most users will only need to specify:
        ncpus (the number of cpu cores per worker)
        min_workers
        max_workers
    """

    def __init__(
        self,
        ncpus,
        min_workers,
        max_workers,
        config=None,
        config_name=DEFAULT_CONFIG_FILENAME,
        persist_config=False,
        local_directory = f"/scratch/{getpass.getuser()}/",
        job_script_prologue=None,
        **kwargs
    ):
        """Configure dask, build the SLURM cluster, and enable adaptive scaling.

        Parameters
        ----------
        ncpus : int
            CPU cores requested per worker; also sizes worker memory
            (15 GB per core) and the threading environment variables below.
        min_workers, max_workers : int
            Bounds for adaptive worker scaling.
        config : dict, optional
            Extra dask config entries; these override the defaults set here.
        config_name : str
            Name of the dask config file written by _modify_dask_config.
        persist_config : bool
            If False, the config file is removed on context-manager exit.
        local_directory : str
            Worker scratch directory (dask 'temporary-directory').
        job_script_prologue : list of str, optional
            Shell lines prepended to each worker's job script.
        **kwargs
            Forwarded to dask_jobqueue.SLURMCluster.
        """
        # Use None sentinels instead of mutable default arguments so the
        # defaults cannot be shared (and accidentally mutated) across calls.
        if config is None:
            config = {}
        if job_script_prologue is None:
            job_script_prologue = []

        # store all args in case needed later (e.g. change_worker_attributes)
        self.locals_store = {**locals()}

        # config
        self.config_name = config_name
        self.persist_config = persist_config
        config_defaults = {
            'temporary-directory':local_directory,
            'distributed.comm.timeouts.connect':'180s',
            'distributed.comm.timeouts.tcp':'360s',
        }
        # user-supplied entries win over the defaults
        config = {**config_defaults, **config}
        _modify_dask_config(config, config_name)

        # threading is best in low level libraries
        job_script_prologue = [
            f"export MKL_NUM_THREADS={2*ncpus}",
            f"export NUM_MKL_THREADS={2*ncpus}",
            f"export OPENBLAS_NUM_THREADS={2*ncpus}",
            f"export OPENMP_NUM_THREADS={2*ncpus}",
            f"export OMP_NUM_THREADS={2*ncpus}",
        ] + job_script_prologue

        # set log directories
        if "log_directory" not in kwargs:
            log_dir = f"{os.getcwd()}/dask_worker_logs_{os.getpid()}/"
            pathlib.Path(log_dir).mkdir(parents=False, exist_ok=True)
            kwargs["log_directory"] = log_dir

        # construct
        super().__init__(
            processes=1,
            cores=ncpus,
            memory=str(15*ncpus)+'GB',  # 15 GB per requested core
            job_script_prologue=job_script_prologue,
            **kwargs,
        )
        self.client = distributed.Client(self)
        print("Cluster dashboard link: ", self.dashboard_link)

        # set adaptive cluster bounds
        self.adapt_cluster(min_workers, max_workers)


    def __enter__(self): return self
    def __exit__(self, exc_type, exc_value, traceback):
        # Clean up the temporary dask config (unless asked to keep it),
        # then close the client before tearing down the cluster itself.
        if not self.persist_config:
            _remove_config_file(self.config_name)
        self.client.close()
        super().__exit__(exc_type, exc_value, traceback)


    def adapt_cluster(self, min_workers, max_workers):
        # Scale worker jobs between the given bounds based on pending work.
        _ = self.adapt(
            minimum_jobs=min_workers,
            maximum_jobs=max_workers,
            interval='10s',
            wait_count=6,
        )


    def change_worker_attributes(
        self,
        min_workers,
        max_workers,
        **kwargs,
    ):
        """WARNING: this function is dangerous if you don't know what
        you're doing. Don't call this unless you know exactly what
        this does."""
        # Scale to zero, mutate the worker spec options in place, then
        # re-enable adaptive scaling so new workers use the new spec.
        self.scale(0)
        for k, v in kwargs.items():
            self.new_spec['options'][k] = v
        self.adapt_cluster(min_workers, max_workers)


class janeliaLSFCluster(dask_jobqueue.LSFCluster):
"""
Expand Down Expand Up @@ -349,9 +454,6 @@ def create_or_pass_cluster(*args, **kwargs):
"Either cluster or cluster_kwargs must be defined"
if not 'cluster' in kwargs:
cluster_constructor = myLocalCluster
F = lambda x: x in kwargs['cluster_kwargs']
if F('ncpus') and F('min_workers') and F('max_workers'):
cluster_constructor = janeliaLSFCluster
with cluster_constructor(**kwargs['cluster_kwargs']) as cluster:
kwargs['cluster'] = cluster
return func(*args, **kwargs)
Expand Down Expand Up @@ -530,6 +632,7 @@ def remove_overlaps(array, crop, overlap, blocksize):
and can be removed after segmentation is complete
reslice array to remove the overlaps"""
crop_trimmed = list(crop)

for axis in range(array.ndim):
if crop[axis].start != 0:
slc = [slice(None),]*array.ndim
Expand Down Expand Up @@ -748,10 +851,15 @@ class in this module. If you are running on the Janelia LSF cluster, see
output_zarr=temp_zarr,
worker_logs_directory=str(worker_logs_dir),
)

print('Gather data in host process')
results = cluster.client.gather(futures)
if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
cluster.scale(0)

#print('Scale cluster down')
#if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
# cluster.scale(0)

print('Process results locally')
faces, boxes_, box_ids_ = list(zip(*results))
boxes = [box for sublist in boxes_ for box in sublist]
box_ids = np.concatenate(box_ids_).astype(int) # unsure how but without cast these are float64
Expand All @@ -760,27 +868,35 @@ class in this module. If you are running on the Janelia LSF cluster, see
new_labeling_path = temporary_directory + '/new_labeling.npy'
np.save(new_labeling_path, new_labeling)

# stitching step is cheap, we should release gpus and use small workers
if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
cluster.change_worker_attributes(
min_workers=cluster.locals_store['min_workers'],
max_workers=cluster.locals_store['max_workers'],
ncpus=1,
memory="15GB",
mem=int(15e9),
queue=None,
job_extra_directives=[],
)
#print('Change worker attributes')
## stitching step is cheap, we should release gpus and use small workers
#if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
# cluster.change_worker_attributes(
# min_workers=cluster.locals_store['min_workers'],
# max_workers=cluster.locals_store['max_workers'],
# cores=1,
# memory="15GB",
# #mem=int(15e9),
# queue="CPU",
# job_extra_directives=[],
# )

print('Use dask array to relabel segmentations')
segmentation_da = dask.array.from_zarr(temp_zarr)
print('Map dask to blocks')
relabeled = dask.array.map_blocks(
lambda block: np.load(new_labeling_path)[block],
segmentation_da,
dtype=np.uint32,
chunks=segmentation_da.chunks,
)
print("Write output")
dask.array.to_zarr(relabeled, write_path, overwrite=True)
# TODO(erjel): Scale cluster down again?

print("Merge bboxes")
merged_boxes = merge_all_boxes(boxes, new_labeling[box_ids])
print("Merge done")
return zarr.open(write_path, mode='r'), merged_boxes


Expand Down
31 changes: 31 additions & 0 deletions environment-rocm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Conda environment spec for Cellpose with AMD ROCm GPU support.
# PyTorch wheels are pulled from the ROCm 6.4 index; everything else from PyPI.
name: cellpose
dependencies:
  - python==3.9.23
  - pip
  - pip:
      - --index-url https://download.pytorch.org/whl/rocm6.4
      - --extra-index-url https://pypi.org/simple
      - qtpy
      # - PyQt5.sip
      - numpy>=1.20.0
      - scipy
      - torch>=1.6
      - opencv-python-headless
      - pyqtgraph>=0.11.0rc0
      - natsort
      - google-cloud-storage
      - tqdm
      - tifffile
      - fastremap
      - cellpose
      - roifile
      - pyqt5
      - dask
      - distributed
      - dask-image
      - pyyaml
      - zarr
      - dask_jobqueue
      - bokeh
      - fill-voids
      - pooch
5 changes: 4 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
name: cellpose
dependencies:
- python==3.8.5
- python==3.9.23
- pip
- pip:
- --index-url https://download.pytorch.org/whl/cu128 # [win] # Only for Windows: Use non-PyPi torch wheels for CUDA support
- --extra-index-url https://pypi.org/simple # [win] # Only for Windows: Use PyPi for everything else
- qtpy
# - PyQt5.sip
- numpy>=1.20.0
Expand All @@ -26,4 +28,5 @@ dependencies:
- dask_jobqueue
- bokeh
- fill-voids
- pooch