diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8ae75f84c5b..559d89289ea 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -85,14 +85,8 @@ def to_tensor(pic): img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) else: img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) - # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK - if pic.mode == 'YCbCr': - nchannel = 3 - elif pic.mode == 'I;16': - nchannel = 1 - else: - nchannel = len(pic.mode) - img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) # put it from HWC to CHW format # yikes, this transpose takes 80% of the loading time/CPU img = img.transpose(0, 1).transpose(0, 2).contiguous()