Skip to content

Commit 3295cff

Browse files
committed
add flexibility to define compressors for all zarr arrays
Signed-off-by: Behrooz <[email protected]>
1 parent e0c35eb commit 3295cff

File tree

2 files changed

+68
-12
lines changed

2 files changed

+68
-12
lines changed

monai/inferers/merger.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,15 @@ class ZarrAvgMerger(Merger):
208208
merged_shape: the shape of the tensor required to merge the patches.
209209
cropped_shape: the shape of the final merged output tensor.
210210
If not provided, it will be the same as `merged_shape`.
211-
output_dtype: the dtype for the final result. Default is `float32`.
211+
dtype: the dtype for the final merged result. Default is `float32`.
212212
value_dtype: the dtype for value aggregating tensor and the final result. Default is `float32`.
213213
count_dtype: the dtype for sample counting tensor. Default is `uint8`.
214214
store: the zarr store to save the final results. Default is "merged.zarr".
215215
value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
216216
count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
217217
compressor: the compressor for final merged zarr array. Default is "default".
218-
The compressor for temporary zarr arrays (values and counts) will be set to None.
218+
value_compressor: the compressor for value aggregating zarr array. Default is None.
219+
count_compressor: the compressor for sample counting zarr array. Default is None.
219220
chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True.
220221
If True, chunk shape will be guessed from `shape` and `dtype`.
221222
If False, ir will be set to `shape`, i.e., single chunk for the whole array.
@@ -226,26 +227,30 @@ def __init__(
226227
self,
227228
merged_shape: Sequence[int],
228229
cropped_shape: Sequence[int] | None = None,
229-
output_dtype: np.dtype | str = "float32",
230+
dtype: np.dtype | str = "float32",
230231
value_dtype: np.dtype | str = "float32",
231232
count_dtype: np.dtype | str = "uint8",
232233
store: zarr.storage.Store | str = "merged.zarr",
233234
value_store: zarr.storage.Store | str | None = None,
234235
count_store: zarr.storage.Store | str | None = None,
235236
compressor: str = "default",
237+
value_compressor: str | None = None,
238+
count_compressor: str | None = None,
236239
chunks: Sequence[int] | bool = True,
237240
) -> None:
238241
super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape)
239242
if not self.merged_shape:
240243
raise ValueError(f"`merged_shape` must be provided for `ZarrAvgMerger`. {self.merged_shape} is give.")
241-
self.output_dtype = output_dtype
244+
self.output_dtype = dtype
242245
self.value_dtype = value_dtype
243246
self.count_dtype = count_dtype
244247
self.store = store
245248
self.value_store = zarr.storage.TempStore() if value_store is None else value_store
246249
self.count_store = zarr.storage.TempStore() if count_store is None else count_store
247250
self.chunks = chunks
248251
self.compressor = compressor
252+
self.value_compressor = value_compressor
253+
self.count_compressor = count_compressor
249254
self.output = zarr.empty(
250255
shape=self.merged_shape,
251256
chunks=self.chunks,
@@ -258,15 +263,15 @@ def __init__(
258263
shape=self.merged_shape,
259264
chunks=self.chunks,
260265
dtype=self.value_dtype,
261-
compressor=None,
266+
compressor=self.value_compressor,
262267
store=self.value_store,
263268
overwrite=True,
264269
)
265270
self.counts = zarr.zeros(
266271
shape=self.merged_shape,
267272
chunks=self.chunks,
268273
dtype=self.count_dtype,
269-
compressor=None,
274+
compressor=self.count_compressor,
270275
store=self.count_store,
271276
overwrite=True,
272277
)

tests/test_zarr_avg_merger.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
np.seterr(divide="ignore", invalid="ignore")
2626
zarr, has_zarr = optional_import("zarr")
27+
numcodecs, has_numcodecs = optional_import("numcodecs")
2728

2829
TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)
2930
TENSOR_4x4_WITH_NAN = TENSOR_4x4.clone()
@@ -128,8 +129,8 @@
128129
TENSOR_4x4,
129130
]
130131
# with both value_dtype, count_dtype set to double precision
131-
TEST_CASE_8_OUTPUT_DTYPE = [
132-
dict(merged_shape=TENSOR_4x4.shape, output_dtype=np.float64),
132+
TEST_CASE_8_DTYPE = [
133+
dict(merged_shape=TENSOR_4x4.shape, dtype=np.float64),
133134
[
134135
(TENSOR_4x4[..., :2, :2], (0, 0)),
135136
(TENSOR_4x4[..., :2, 2:], (0, 2)),
@@ -196,6 +197,44 @@
196197
]
197198

198199

200+
# test for LZ4 compressor
201+
TEST_CASE_13_COMPRESSOR_LZ4 = [
202+
dict(merged_shape=TENSOR_4x4.shape, compressor="LZ4"),
203+
[
204+
(TENSOR_4x4[..., :2, :2], (0, 0)),
205+
(TENSOR_4x4[..., :2, 2:], (0, 2)),
206+
(TENSOR_4x4[..., 2:, :2], (2, 0)),
207+
(TENSOR_4x4[..., 2:, 2:], (2, 2)),
208+
],
209+
TENSOR_4x4,
210+
]
211+
212+
# test for pickle compressor
213+
TEST_CASE_14_COMPRESSOR_PICKLE = [
214+
dict(merged_shape=TENSOR_4x4.shape, compressor="Pickle"),
215+
[
216+
(TENSOR_4x4[..., :2, :2], (0, 0)),
217+
(TENSOR_4x4[..., :2, 2:], (0, 2)),
218+
(TENSOR_4x4[..., 2:, :2], (2, 0)),
219+
(TENSOR_4x4[..., 2:, 2:], (2, 2)),
220+
],
221+
TENSOR_4x4,
222+
]
223+
224+
# test for LZMA compressor
225+
TEST_CASE_15_COMPRESSOR_LZMA = [
226+
dict(merged_shape=TENSOR_4x4.shape, compressor="LZMA"),
227+
[
228+
(TENSOR_4x4[..., :2, :2], (0, 0)),
229+
(TENSOR_4x4[..., :2, 2:], (0, 2)),
230+
(TENSOR_4x4[..., 2:, :2], (2, 0)),
231+
(TENSOR_4x4[..., 2:, 2:], (2, 2)),
232+
],
233+
TENSOR_4x4,
234+
]
235+
236+
237+
@unittest.skipIf(not has_zarr or not has_numcodecs, "Requires zarr (and numcodecs) packages.)")
199238
class ZarrAvgMergerTests(unittest.TestCase):
200239
@parameterized.expand(
201240
[
@@ -207,14 +246,26 @@ class ZarrAvgMergerTests(unittest.TestCase):
207246
TEST_CASE_5_VALUE_DTYPE,
208247
TEST_CASE_6_COUNT_DTYPE,
209248
TEST_CASE_7_COUNT_VALUE_DTYPE,
210-
TEST_CASE_8_OUTPUT_DTYPE,
249+
TEST_CASE_8_DTYPE,
211250
TEST_CASE_9_LARGER_SHAPE,
212251
TEST_CASE_10_DIRECTORY_STORE,
213252
TEST_CASE_11_MEMORY_STORE,
214253
TEST_CASE_12_CHUNKS,
254+
TEST_CASE_13_COMPRESSOR_LZ4,
255+
TEST_CASE_14_COMPRESSOR_PICKLE,
256+
TEST_CASE_15_COMPRESSOR_LZMA,
215257
]
216258
)
217-
def test_avg_merger_patches(self, arguments, patch_locations, expected):
259+
def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected):
260+
if "compressor" in arguments:
261+
if arguments["compressor"] != "default":
262+
arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]()
263+
if "value_compressor" in arguments:
264+
if arguments["value_compressor"] != "default":
265+
arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]()
266+
if "count_compressor" in arguments:
267+
if arguments["count_compressor"] != "default":
268+
arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]()
218269
merger = ZarrAvgMerger(**arguments)
219270
for pl in patch_locations:
220271
merger.aggregate(pl[0], pl[1])
@@ -228,13 +279,13 @@ def test_avg_merger_patches(self, arguments, patch_locations, expected):
228279
# check if the result is matching the expectation
229280
assert_allclose(output[:], expected.numpy())
230281

231-
def test_avg_merger_finalized_error(self):
282+
def test_zarr_avg_merger_finalized_error(self):
232283
with self.assertRaises(ValueError):
233284
merger = ZarrAvgMerger(merged_shape=(1, 3, 2, 3))
234285
merger.finalize()
235286
merger.aggregate(torch.zeros(1, 3, 2, 2), (3, 3))
236287

237-
def test_avg_merge_none_merged_shape_error(self):
288+
def test_zarr_avg_merge_none_merged_shape_error(self):
238289
with self.assertRaises(ValueError):
239290
ZarrAvgMerger(merged_shape=None)
240291

0 commit comments

Comments
 (0)