Skip to content

Commit 5736e45

Browse files
authored
Merge branch 'main' into patch-1
2 parents c260638 + fd9b61d commit 5736e45

File tree

8 files changed

+321
-16
lines changed

8 files changed

+321
-16
lines changed

test/test_image.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torchvision.transforms.functional as F
1111
from common_utils import assert_equal, needs_cuda
12-
from PIL import __version__ as PILLOW_VERSION, Image
12+
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps
1313
from torchvision.io.image import (
1414
_read_png_16,
1515
decode_image,
@@ -100,6 +100,44 @@ def test_decode_jpeg(img_path, pil_mode, mode):
100100
assert abs_mean_diff < 2
101101

102102

103+
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_jpeg_with_exif_orientation(tmpdir, orientation):
    """Decoding with apply_exif_orientation=True must match PIL's exif_transpose.

    A random image is saved as JPEG with the given EXIF orientation value, then
    decoded both by torchvision and by PIL, and the two results are compared.
    """
    jpeg_path = os.path.join(tmpdir, f"exif_oriented_{orientation}.jpg")

    # Write a random RGB image tagged with the requested orientation.
    pixels = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    source_img = F.to_pil_image(pixels)
    exif_block = source_img.getexif()
    exif_block[0x0112] = orientation  # 0x0112 is the EXIF orientation tag
    source_img.save(jpeg_path, "JPEG", exif=exif_block.tobytes())

    # Decode through torchvision, asking it to honour the orientation tag.
    decoded = decode_image(read_file(jpeg_path), apply_exif_orientation=True)

    # Reference result: let PIL apply the orientation.
    reference_img = ImageOps.exif_transpose(Image.open(jpeg_path))
    reference = F.pil_to_tensor(reference_img)

    torch.testing.assert_close(reference, decoded)
120+
121+
122+
@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
def test_invalid_exif(tmpdir, size):
    """Garbage EXIF payloads of various sizes must decode the same way as in PIL."""
    # Inspired from a PIL test:
    # https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
    jpeg_path = os.path.join(tmpdir, "invalid_exif.jpg")

    # Save a random image whose EXIF segment is `size` bytes of b"1".
    pixels = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    source_img = F.to_pil_image(pixels)
    source_img.save(jpeg_path, "JPEG", exif=b"1" * size)

    # Decode through torchvision with EXIF handling enabled.
    decoded = decode_image(read_file(jpeg_path), apply_exif_orientation=True)

    # Reference result produced by PIL on the same file.
    reference_img = ImageOps.exif_transpose(Image.open(jpeg_path))
    reference = F.pil_to_tensor(reference_img)

    torch.testing.assert_close(reference, decoded)
139+
140+
103141
def test_decode_jpeg_errors():
104142
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
105143
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

torchvision/csrc/io/image/cpu/decode_image.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
namespace vision {
77
namespace image {
88

9-
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
9+
torch::Tensor decode_image(
10+
const torch::Tensor& data,
11+
ImageReadMode mode,
12+
bool apply_exif_orientation) {
1013
// Check that tensor is a CPU tensor
1114
TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor");
1215
// Check that the input tensor dtype is uint8
@@ -22,7 +25,7 @@ torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
2225
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
2326

2427
if (memcmp(jpeg_signature, datap, 3) == 0) {
25-
return decode_jpeg(data, mode);
28+
return decode_jpeg(data, mode, apply_exif_orientation);
2629
} else if (memcmp(png_signature, datap, 4) == 0) {
2730
return decode_png(data, mode);
2831
} else {

torchvision/csrc/io/image/cpu/decode_image.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ namespace image {
88

99
C10_EXPORT torch::Tensor decode_image(
1010
const torch::Tensor& data,
11-
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
11+
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
12+
bool apply_exif_orientation = false);
1213

1314
} // namespace image
1415
} // namespace vision

torchvision/csrc/io/image/cpu/decode_jpeg.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
#include "decode_jpeg.h"
22
#include "common_jpeg.h"
3+
#include "exif.h"
34

45
namespace vision {
56
namespace image {
67

78
#if !JPEG_FOUND
8-
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
9+
torch::Tensor decode_jpeg(
10+
const torch::Tensor& data,
11+
ImageReadMode mode,
12+
bool apply_exif_orientation) {
913
TORCH_CHECK(
1014
false, "decode_jpeg: torchvision not compiled with libjpeg support");
1115
}
1216
#else
1317

1418
using namespace detail;
19+
using namespace exif_private;
1520

1621
namespace {
1722

@@ -65,6 +70,8 @@ static void torch_jpeg_set_source_mgr(
6570
src->len = len;
6671
src->pub.bytes_in_buffer = len;
6772
src->pub.next_input_byte = src->data;
73+
74+
jpeg_save_markers(cinfo, APP1, 0xffff);
6875
}
6976

7077
inline unsigned char clamped_cmyk_rgb_convert(
@@ -121,7 +128,10 @@ void convert_line_cmyk_to_gray(
121128

122129
} // namespace
123130

124-
torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
131+
torch::Tensor decode_jpeg(
132+
const torch::Tensor& data,
133+
ImageReadMode mode,
134+
bool apply_exif_orientation) {
125135
C10_LOG_API_USAGE_ONCE(
126136
"torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");
127137
// Check that the input tensor dtype is uint8
@@ -191,6 +201,11 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
191201
jpeg_calc_output_dimensions(&cinfo);
192202
}
193203

204+
int exif_orientation = -1;
205+
if (apply_exif_orientation) {
206+
exif_orientation = fetch_exif_orientation(&cinfo);
207+
}
208+
194209
jpeg_start_decompress(&cinfo);
195210

196211
int height = cinfo.output_height;
@@ -227,7 +242,12 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
227242

228243
jpeg_finish_decompress(&cinfo);
229244
jpeg_destroy_decompress(&cinfo);
230-
return tensor.permute({2, 0, 1});
245+
auto output = tensor.permute({2, 0, 1});
246+
247+
if (apply_exif_orientation) {
248+
return exif_orientation_transform(output, exif_orientation);
249+
}
250+
return output;
231251
}
232252
#endif // #if !JPEG_FOUND
233253

torchvision/csrc/io/image/cpu/decode_jpeg.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ namespace image {
88

99
C10_EXPORT torch::Tensor decode_jpeg(
1010
const torch::Tensor& data,
11-
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
11+
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
12+
bool apply_exif_orientation = false);
1213

1314
C10_EXPORT int64_t _jpeg_version();
1415
C10_EXPORT bool _is_compiled_against_turbo();
torchvision/csrc/io/image/cpu/exif.h

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
/*M///////////////////////////////////////////////////////////////////////////////////////
2+
//
3+
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4+
//
5+
// By downloading, copying, installing or using the software you agree to this
6+
license.
7+
// If you do not agree to this license, do not download, install,
8+
// copy or use the software.
9+
//
10+
//
11+
// License Agreement
12+
// For Open Source Computer Vision Library
13+
//
14+
// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
15+
// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
16+
// Third party copyrights are property of their respective owners.
17+
//
18+
// Redistribution and use in source and binary forms, with or without
19+
modification,
20+
// are permitted provided that the following conditions are met:
21+
//
22+
// * Redistribution's of source code must retain the above copyright notice,
23+
// this list of conditions and the following disclaimer.
24+
//
25+
// * Redistribution's in binary form must reproduce the above copyright
26+
notice,
27+
// this list of conditions and the following disclaimer in the documentation
28+
// and/or other materials provided with the distribution.
29+
//
30+
// * The name of the copyright holders may not be used to endorse or promote
31+
products
32+
// derived from this software without specific prior written permission.
33+
//
34+
// This software is provided by the copyright holders and contributors "as is"
35+
and
36+
// any express or implied warranties, including, but not limited to, the implied
37+
// warranties of merchantability and fitness for a particular purpose are
38+
disclaimed.
39+
// In no event shall the Intel Corporation or contributors be liable for any
40+
direct,
41+
// indirect, incidental, special, exemplary, or consequential damages
42+
// (including, but not limited to, procurement of substitute goods or services;
43+
// loss of use, data, or profits; or business interruption) however caused
44+
// and on any theory of liability, whether in contract, strict liability,
45+
// or tort (including negligence or otherwise) arising in any way out of
46+
// the use of this software, even if advised of the possibility of such damage.
47+
//
48+
//M*/
49+
#pragma once
50+
// Functions in this module are taken from OpenCV
51+
// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/imgcodecs/src/exif.cpp
52+
53+
#if JPEG_FOUND
54+
55+
#include <jpeglib.h>
56+
#include <torch/types.h>
57+
58+
namespace vision {
59+
namespace image {
60+
namespace exif_private {
61+
62+
// Marker code of the JPEG APP1 segment, the segment searched for Exif data.
constexpr uint16_t APP1 = 0xe1;
// Byte-order marks found at the start of the TIFF header: 'I' (0x49) selects
// least-significant-byte-first reads, 'M' (0x4d) most-significant-byte-first.
constexpr uint16_t ENDIANNESS_INTEL = 0x49;
constexpr uint16_t ENDIANNESS_MOTO = 0x4d;
// Tag mark (0x002A) expected right after the byte-order mark.
constexpr uint16_t REQ_EXIF_TAG_MARK = 0x2a;
// Tag id of the Exif orientation field.
constexpr uint16_t ORIENTATION_EXIF_TAG = 0x0112;
// Sentinel returned by get_uint16 / get_uint32 when a read would fall outside
// the buffer. Written as 0xffff, which is the value the former `-1`
// initializer silently wrapped to, so the numeric value is unchanged.
constexpr uint16_t INCORRECT_TAG = 0xffff;
68+
69+
class ExifDataReader {
70+
public:
71+
ExifDataReader(unsigned char* p, size_t s) : _ptr(p), _size(s) {}
72+
size_t size() const {
73+
return _size;
74+
}
75+
const unsigned char& operator[](size_t index) const {
76+
TORCH_CHECK(index >= 0 && index < _size);
77+
return _ptr[index];
78+
}
79+
80+
protected:
81+
unsigned char* _ptr;
82+
size_t _size;
83+
};
84+
85+
inline uint16_t get_endianness(const ExifDataReader& exif_data) {
86+
if ((exif_data.size() < 1) ||
87+
(exif_data.size() > 1 && exif_data[0] != exif_data[1])) {
88+
return 0;
89+
}
90+
if (exif_data[0] == 'I') {
91+
return ENDIANNESS_INTEL;
92+
}
93+
if (exif_data[0] == 'M') {
94+
return ENDIANNESS_MOTO;
95+
}
96+
return 0;
97+
}
98+
99+
inline uint16_t get_uint16(
100+
const ExifDataReader& exif_data,
101+
uint16_t endianness,
102+
const size_t offset) {
103+
if (offset + 1 >= exif_data.size()) {
104+
return INCORRECT_TAG;
105+
}
106+
107+
if (endianness == ENDIANNESS_INTEL) {
108+
return exif_data[offset] + (exif_data[offset + 1] << 8);
109+
}
110+
return (exif_data[offset] << 8) + exif_data[offset + 1];
111+
}
112+
113+
inline uint32_t get_uint32(
114+
const ExifDataReader& exif_data,
115+
uint16_t endianness,
116+
const size_t offset) {
117+
if (offset + 3 >= exif_data.size()) {
118+
return INCORRECT_TAG;
119+
}
120+
121+
if (endianness == ENDIANNESS_INTEL) {
122+
return exif_data[offset] + (exif_data[offset + 1] << 8) +
123+
(exif_data[offset + 2] << 16) + (exif_data[offset + 3] << 24);
124+
}
125+
return (exif_data[offset] << 24) + (exif_data[offset + 1] << 16) +
126+
(exif_data[offset + 2] << 8) + exif_data[offset + 3];
127+
}
128+
129+
// Looks up the Exif orientation value in the APP1 markers saved on `cinfo`.
//
// The marker list is only populated for marker types that were registered
// with jpeg_save_markers before the header was read.
//
// Returns the raw value of the orientation tag, or -1 when no Exif APP1
// segment is present, when its TIFF header is not recognised, or when the
// orientation tag cannot be read.
inline int fetch_exif_orientation(j_decompress_ptr cinfo) {
  // Exif binary structure looks like this
  //   First 6 bytes: [E, x, i, f, 0, 0]
  //   Endianness, 2 bytes : [M, M] or [I, I]
  //   Tag mark, 2 bytes: [0, 0x2a]
  //   Offset, 4 bytes
  //   Num entries, 2 bytes
  //   Tag entries and data, tag has 2 bytes and its data has 10 bytes
  // For more details:
  // http://www.media.mit.edu/pia/Research/deepview/exif.html
  constexpr unsigned char exif_header[] = {'E', 'x', 'i', 'f', 0, 0};
  // Bytes from Exif size field to the first TIFF header
  constexpr size_t start_offset = sizeof(exif_header);
  // Endianness (2 bytes) + tag mark (2 bytes) + offset (4 bytes)
  constexpr size_t tiff_header_size = 8;
  constexpr size_t tiff_field_size = 12;
  constexpr int not_found = -1;

  // Pick the first APP1 segment that really carries Exif data. APP1 is also
  // used for other payloads, so the identifier has to be verified instead of
  // blindly taking the first APP1 marker.
  jpeg_saved_marker_ptr exif_marker = nullptr;
  for (jpeg_saved_marker_ptr cmarker = cinfo->marker_list; cmarker != nullptr;
       cmarker = cmarker->next) {
    if (cmarker->marker != APP1 || cmarker->data_length <= start_offset) {
      continue;
    }
    bool has_exif_header = true;
    for (size_t i = 0; i < start_offset; ++i) {
      if (cmarker->data[i] != exif_header[i]) {
        has_exif_header = false;
        break;
      }
    }
    if (has_exif_header) {
      exif_marker = cmarker;
      break;
    }
  }
  if (exif_marker == nullptr) {
    return not_found;
  }

  ExifDataReader exif_data(
      exif_marker->data + start_offset,
      exif_marker->data_length - start_offset);

  // An unknown byte-order mark means this is not a TIFF header; do not guess.
  const uint16_t endianness = get_endianness(exif_data);
  if (endianness == 0 || exif_data.size() < tiff_header_size) {
    return not_found;
  }

  // Checking whether Tag Mark (0x002A) correspond to one contained in the
  // Jpeg file
  if (get_uint16(exif_data, endianness, 2) != REQ_EXIF_TAG_MARK) {
    return not_found;
  }

  // Keep the offset in size_t so that advancing it cannot wrap in 32 bits.
  size_t offset = get_uint32(exif_data, endianness, 4);
  if (offset >= exif_data.size()) {
    return not_found;
  }
  const size_t num_entry = get_uint16(exif_data, endianness, offset);
  offset += 2; // go to start of tag fields

  for (size_t entry = 0; entry < num_entry; ++entry) {
    // Here we just search for orientation tag and parse it
    const uint16_t tag_num = get_uint16(exif_data, endianness, offset);
    if (tag_num == INCORRECT_TAG) {
      break;
    }
    if (tag_num == ORIENTATION_EXIF_TAG) {
      // The value sits 8 bytes into the 12-byte field. A truncated field
      // yields the INCORRECT_TAG sentinel, which is reported as "not found"
      // rather than leaked to the caller as an orientation.
      const uint16_t value = get_uint16(exif_data, endianness, offset + 8);
      return value == INCORRECT_TAG ? not_found : static_cast<int>(value);
    }
    offset += tiff_field_size;
  }
  return not_found;
}
186+
187+
// Values of the Exif orientation tag and the correction each one calls for.

// normal orientation
constexpr uint16_t IMAGE_ORIENTATION_TL = 1;
// needs horizontal flip
constexpr uint16_t IMAGE_ORIENTATION_TR = 2;
// needs 180 rotation
constexpr uint16_t IMAGE_ORIENTATION_BR = 3;
// needs vertical flip
constexpr uint16_t IMAGE_ORIENTATION_BL = 4;
// mirrored horizontal & rotate 270 CW
constexpr uint16_t IMAGE_ORIENTATION_LT = 5;
// rotate 90 CW
constexpr uint16_t IMAGE_ORIENTATION_RT = 6;
// mirrored horizontal & rotate 90 CW
constexpr uint16_t IMAGE_ORIENTATION_RB = 7;
// needs 270 CW rotation
constexpr uint16_t IMAGE_ORIENTATION_LB = 8;
197+
198+
// Applies the flip / transpose combination that corresponds to the given Exif
// orientation value to the last two dimensions of `image`.
// Any value outside 2..8 (including 1, the normal orientation, and -1) returns
// `image` unchanged.
inline torch::Tensor exif_orientation_transform(
    const torch::Tensor& image,
    int orientation) {
  switch (orientation) {
    case IMAGE_ORIENTATION_TR:
      return image.flip(-1);
    case IMAGE_ORIENTATION_BR:
      // needs 180 rotation equivalent to
      // flip both horizontally and vertically
      return image.flip({-2, -1});
    case IMAGE_ORIENTATION_BL:
      return image.flip(-2);
    case IMAGE_ORIENTATION_LT:
      return image.transpose(-1, -2);
    case IMAGE_ORIENTATION_RT: {
      const auto swapped = image.transpose(-1, -2);
      return swapped.flip(-1);
    }
    case IMAGE_ORIENTATION_RB: {
      const auto swapped = image.transpose(-1, -2);
      return swapped.flip({-2, -1});
    }
    case IMAGE_ORIENTATION_LB: {
      const auto swapped = image.transpose(-1, -2);
      return swapped.flip(-2);
    }
    case IMAGE_ORIENTATION_TL:
    default:
      return image;
  }
}
222+
223+
} // namespace exif_private
224+
} // namespace image
225+
} // namespace vision
226+
227+
#endif

0 commit comments

Comments
 (0)