2424
2525np .seterr (divide = "ignore" , invalid = "ignore" )
2626zarr , has_zarr = optional_import ("zarr" )
27+ numcodecs , has_numcodecs = optional_import ("numcodecs" )
2728
2829TENSOR_4x4 = torch .randint (low = 0 , high = 255 , size = (2 , 3 , 4 , 4 ), dtype = torch .float32 )
2930TENSOR_4x4_WITH_NAN = TENSOR_4x4 .clone ()
128129 TENSOR_4x4 ,
129130]
130131# with both value_dtype, count_dtype set to double precision
131- TEST_CASE_8_OUTPUT_DTYPE = [
132- dict (merged_shape = TENSOR_4x4 .shape , output_dtype = np .float64 ),
132+ TEST_CASE_8_DTYPE = [
133+ dict (merged_shape = TENSOR_4x4 .shape , dtype = np .float64 ),
133134 [
134135 (TENSOR_4x4 [..., :2 , :2 ], (0 , 0 )),
135136 (TENSOR_4x4 [..., :2 , 2 :], (0 , 2 )),
196197]
197198
198199
200+ # test for LZ4 compressor
201+ TEST_CASE_13_COMPRESSOR_LZ4 = [
202+ dict (merged_shape = TENSOR_4x4 .shape , compressor = "LZ4" ),
203+ [
204+ (TENSOR_4x4 [..., :2 , :2 ], (0 , 0 )),
205+ (TENSOR_4x4 [..., :2 , 2 :], (0 , 2 )),
206+ (TENSOR_4x4 [..., 2 :, :2 ], (2 , 0 )),
207+ (TENSOR_4x4 [..., 2 :, 2 :], (2 , 2 )),
208+ ],
209+ TENSOR_4x4 ,
210+ ]
211+
212+ # test for pickle compressor
213+ TEST_CASE_14_COMPRESSOR_PICKLE = [
214+ dict (merged_shape = TENSOR_4x4 .shape , compressor = "Pickle" ),
215+ [
216+ (TENSOR_4x4 [..., :2 , :2 ], (0 , 0 )),
217+ (TENSOR_4x4 [..., :2 , 2 :], (0 , 2 )),
218+ (TENSOR_4x4 [..., 2 :, :2 ], (2 , 0 )),
219+ (TENSOR_4x4 [..., 2 :, 2 :], (2 , 2 )),
220+ ],
221+ TENSOR_4x4 ,
222+ ]
223+
224+ # test for LZMA compressor
225+ TEST_CASE_15_COMPRESSOR_LZMA = [
226+ dict (merged_shape = TENSOR_4x4 .shape , compressor = "LZMA" ),
227+ [
228+ (TENSOR_4x4 [..., :2 , :2 ], (0 , 0 )),
229+ (TENSOR_4x4 [..., :2 , 2 :], (0 , 2 )),
230+ (TENSOR_4x4 [..., 2 :, :2 ], (2 , 0 )),
231+ (TENSOR_4x4 [..., 2 :, 2 :], (2 , 2 )),
232+ ],
233+ TENSOR_4x4 ,
234+ ]
235+
236+
237+ @unittest .skipIf (not has_zarr or not has_numcodecs , "Requires zarr (and numcodecs) packages.)" )
199238class ZarrAvgMergerTests (unittest .TestCase ):
200239 @parameterized .expand (
201240 [
@@ -207,14 +246,26 @@ class ZarrAvgMergerTests(unittest.TestCase):
207246 TEST_CASE_5_VALUE_DTYPE ,
208247 TEST_CASE_6_COUNT_DTYPE ,
209248 TEST_CASE_7_COUNT_VALUE_DTYPE ,
210- TEST_CASE_8_OUTPUT_DTYPE ,
249+ TEST_CASE_8_DTYPE ,
211250 TEST_CASE_9_LARGER_SHAPE ,
212251 TEST_CASE_10_DIRECTORY_STORE ,
213252 TEST_CASE_11_MEMORY_STORE ,
214253 TEST_CASE_12_CHUNKS ,
254+ TEST_CASE_13_COMPRESSOR_LZ4 ,
255+ TEST_CASE_14_COMPRESSOR_PICKLE ,
256+ TEST_CASE_15_COMPRESSOR_LZMA ,
215257 ]
216258 )
217- def test_avg_merger_patches (self , arguments , patch_locations , expected ):
259+ def test_zarr_avg_merger_patches (self , arguments , patch_locations , expected ):
260+ if "compressor" in arguments :
261+ if arguments ["compressor" ] != "default" :
262+ arguments ["compressor" ] = zarr .codec_registry [arguments ["compressor" ].lower ()]()
263+ if "value_compressor" in arguments :
264+ if arguments ["value_compressor" ] != "default" :
265+ arguments ["value_compressor" ] = zarr .codec_registry [arguments ["value_compressor" ].lower ()]()
266+ if "count_compressor" in arguments :
267+ if arguments ["count_compressor" ] != "default" :
268+ arguments ["count_compressor" ] = zarr .codec_registry [arguments ["count_compressor" ].lower ()]()
218269 merger = ZarrAvgMerger (** arguments )
219270 for pl in patch_locations :
220271 merger .aggregate (pl [0 ], pl [1 ])
@@ -228,13 +279,13 @@ def test_avg_merger_patches(self, arguments, patch_locations, expected):
228279 # check if the result is matching the expectation
229280 assert_allclose (output [:], expected .numpy ())
230281
231- def test_avg_merger_finalized_error (self ):
282+ def test_zarr_avg_merger_finalized_error (self ):
232283 with self .assertRaises (ValueError ):
233284 merger = ZarrAvgMerger (merged_shape = (1 , 3 , 2 , 3 ))
234285 merger .finalize ()
235286 merger .aggregate (torch .zeros (1 , 3 , 2 , 2 ), (3 , 3 ))
236287
237- def test_avg_merge_none_merged_shape_error (self ):
288+ def test_zarr_avg_merge_none_merged_shape_error (self ):
238289 with self .assertRaises (ValueError ):
239290 ZarrAvgMerger (merged_shape = None )
240291
0 commit comments