@@ -392,13 +392,14 @@ def test_torch_formatter_image(self):
392392 formatter = TorchFormatter (features = Features ({"image" : Image ()}))
393393 row = formatter .format_row (pa_table )
394394 self .assertEqual (row ["image" ].dtype , torch .uint8 )
395- self .assertEqual (row ["image" ].shape , (480 , 640 , 3 ))
395+ # torch uses CHW format contrary to numpy which uses HWC
396+ self .assertEqual (row ["image" ].shape , (3 , 480 , 640 ))
396397 col = formatter .format_column (pa_table )
397398 self .assertEqual (col .dtype , torch .uint8 )
398- self .assertEqual (col .shape , (2 , 480 , 640 , 3 ))
399+ self .assertEqual (col .shape , (2 , 3 , 480 , 640 ))
399400 batch = formatter .format_batch (pa_table )
400401 self .assertEqual (batch ["image" ].dtype , torch .uint8 )
401- self .assertEqual (batch ["image" ].shape , (2 , 480 , 640 , 3 ))
402+ self .assertEqual (batch ["image" ].shape , (2 , 3 , 480 , 640 ))
402403
403404 # different dimensions
404405 pa_table = pa .table (
@@ -407,15 +408,15 @@ def test_torch_formatter_image(self):
407408 formatter = TorchFormatter (features = Features ({"image" : Image ()}))
408409 row = formatter .format_row (pa_table )
409410 self .assertEqual (row ["image" ].dtype , torch .uint8 )
410- self .assertEqual (row ["image" ].shape , (480 , 640 , 3 ))
411+ self .assertEqual (row ["image" ].shape , (3 , 480 , 640 ))
411412 col = formatter .format_column (pa_table )
412413 self .assertIsInstance (col , list )
413414 self .assertEqual (col [0 ].dtype , torch .uint8 )
414- self .assertEqual (col [0 ].shape , (480 , 640 , 3 ))
415+ self .assertEqual (col [0 ].shape , (3 , 480 , 640 ))
415416 batch = formatter .format_batch (pa_table )
416417 self .assertIsInstance (batch ["image" ], list )
417418 self .assertEqual (batch ["image" ][0 ].dtype , torch .uint8 )
418- self .assertEqual (batch ["image" ][0 ].shape , (480 , 640 , 3 ))
419+ self .assertEqual (batch ["image" ][0 ].shape , (3 , 480 , 640 ))
419420
420421 @require_torch
421422 @require_sndfile
0 commit comments