Skip to content

Commit d981306

Browse files
committed
Improve error messages in load_image function for clarity
1 parent 91045e1 commit d981306

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

tests/torchtune/data/test_data_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def mock_urlopen(url):
122122
assert image.size() == (3, 403, 580)
123123

124124
# Test that a ValueError is raised when the image path is invalid
125-
with pytest.raises(ValueError, match="Failed to open image as torch.Tensor"):
125+
with pytest.raises(ValueError, match="Failed to load local image as torch.Tensor"):
126126
load_image("invalid_path")
127127

128128
# Test a temporary file with invalid image data
@@ -131,16 +131,16 @@ def mock_urlopen(url):
131131
f.write("Invalid image data")
132132

133133
# Test that a ValueError is raised when the image data is invalid
134-
with pytest.raises(ValueError, match="Failed to open image as torch.Tensor"):
134+
with pytest.raises(ValueError, match="Failed to load local image as torch.Tensor"):
135135
load_image(str(image_path))
136136

137137
# Test that a ValueError is raised when there is an HTTP error
138138
# Mock the urlopen function to raise an exception
139139
def mock_urlopen(url):
140-
raise Exception("Failed to load image")
140+
raise Exception("Failed to load remote image as torch.Tensor")
141141

142142
monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
143-
with pytest.raises(ValueError, match="Failed to load image"):
143+
with pytest.raises(ValueError, match="Failed to load remote image as torch.Tensor"):
144144
load_image("http://example.com/test_image.jpg")
145145

146146
# Test that a ValueError is raised when there is an IO error
@@ -149,7 +149,7 @@ def mock_urlopen(url):
149149
with open(image_path, "w") as f:
150150
f.write("Test data")
151151
os.chmod(image_path, 0o000) # Remove read permissions
152-
with pytest.raises(ValueError, match="Failed to open image as torch.Tensor"):
152+
with pytest.raises(ValueError, match="Failed to load local image as torch.Tensor"):
153153
load_image(str(image_path))
154154
os.chmod(image_path, 0o644) # Restore read permissions
155155

@@ -158,5 +158,5 @@ def mock_urlopen(url):
158158
image_path = tmp_path / "test_image.jpg"
159159
with open(image_path, "wb") as f:
160160
f.write(b"Invalid image data")
161-
with pytest.raises(ValueError, match="Failed to open image as torch.Tensor"):
161+
with pytest.raises(ValueError, match="Failed to load local image as torch.Tensor"):
162162
load_image(str(image_path))

torchtune/data/_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def truncate(
4646
return tokens_truncated
4747

4848

49-
def load_image(image_loc: Union[Path, str]) -> "torch.Tensor":
49+
def load_image(image_loc: Union[Path, str]) -> torch.Tensor:
5050
"""
5151
Convenience method to load an image in torch.Tensor format from a local file path or remote source.
5252
@@ -81,14 +81,14 @@ def load_image(image_loc: Union[Path, str]) -> "torch.Tensor":
8181
torch.frombuffer(image_loc, dtype=torch.uint8)
8282
)
8383
except Exception as e:
84-
raise ValueError("Failed to load image as torch.Tensor") from e
84+
raise ValueError("Failed to load remote image as torch.Tensor") from e
8585

8686
# Open the local image as a Tensor image
8787
else:
8888
try:
8989
image = torchvision.io.decode_image(image_loc)
9090
except Exception as e:
91-
raise ValueError("Failed to open image as torch.Tensor") from e
91+
raise ValueError("Failed to load local image as torch.Tensor") from e
9292
return image
9393

9494

0 commit comments

Comments
 (0)