Skip to content

Commit 3f04994

Browse files
committed
Add RGB mode support for image loading in load_image function
1 parent d981306 commit 3f04994

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchtune/data/_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,16 @@ def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
7878
try:
7979
image_loc = request.urlopen(image_loc).read()
8080
image = torchvision.io.decode_image(
81-
torch.frombuffer(image_loc, dtype=torch.uint8)
81+
torch.frombuffer(image_loc, dtype=torch.uint8),
82+
mode="RGB",
8283
)
8384
except Exception as e:
8485
raise ValueError("Failed to load remote image as torch.Tensor") from e
8586

8687
# Open the local image as a Tensor image
8788
else:
8889
try:
89-
image = torchvision.io.decode_image(image_loc)
90+
image = torchvision.io.decode_image(image_loc, mode="RGB")
9091
except Exception as e:
9192
raise ValueError("Failed to load local image as torch.Tensor") from e
9293
return image

0 commit comments

Comments
 (0)