Skip to content

Commit a02997d

Browse files
varunneallhoestq
andauthored
Update torch_formatter.py (#6402)
* Update torch_formatter.py Transpose images to (C, H, W) instead of (H, W, C). * handle monochromat PIL image * Update test_formatting.py --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 4591ac1 commit a02997d

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

src/datasets/formatting/torch_formatter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def _tensorize(self, value):
7171

7272
if isinstance(value, PIL.Image.Image):
7373
value = np.asarray(value)
74+
if value.ndim == 2:
75+
value = value[:, :, np.newaxis]
76+
77+
value = value.transpose((2, 0, 1))
7478
return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
7579

7680
def _recursive_tensorize(self, data_struct):

tests/test_formatting.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,14 @@ def test_torch_formatter_image(self):
392392
formatter = TorchFormatter(features=Features({"image": Image()}))
393393
row = formatter.format_row(pa_table)
394394
self.assertEqual(row["image"].dtype, torch.uint8)
395-
self.assertEqual(row["image"].shape, (480, 640, 3))
395+
# torch uses CHW format contrary to numpy which uses HWC
396+
self.assertEqual(row["image"].shape, (3, 480, 640))
396397
col = formatter.format_column(pa_table)
397398
self.assertEqual(col.dtype, torch.uint8)
398-
self.assertEqual(col.shape, (2, 480, 640, 3))
399+
self.assertEqual(col.shape, (2, 3, 480, 640))
399400
batch = formatter.format_batch(pa_table)
400401
self.assertEqual(batch["image"].dtype, torch.uint8)
401-
self.assertEqual(batch["image"].shape, (2, 480, 640, 3))
402+
self.assertEqual(batch["image"].shape, (2, 3, 480, 640))
402403

403404
# different dimensions
404405
pa_table = pa.table(
@@ -407,15 +408,15 @@ def test_torch_formatter_image(self):
407408
formatter = TorchFormatter(features=Features({"image": Image()}))
408409
row = formatter.format_row(pa_table)
409410
self.assertEqual(row["image"].dtype, torch.uint8)
410-
self.assertEqual(row["image"].shape, (480, 640, 3))
411+
self.assertEqual(row["image"].shape, (3, 480, 640))
411412
col = formatter.format_column(pa_table)
412413
self.assertIsInstance(col, list)
413414
self.assertEqual(col[0].dtype, torch.uint8)
414-
self.assertEqual(col[0].shape, (480, 640, 3))
415+
self.assertEqual(col[0].shape, (3, 480, 640))
415416
batch = formatter.format_batch(pa_table)
416417
self.assertIsInstance(batch["image"], list)
417418
self.assertEqual(batch["image"][0].dtype, torch.uint8)
418-
self.assertEqual(batch["image"][0].shape, (480, 640, 3))
419+
self.assertEqual(batch["image"][0].shape, (3, 480, 640))
419420

420421
@require_torch
421422
@require_sndfile

0 commit comments

Comments
 (0)