Skip to content

Commit a3d386f

Browse files
minor changes addressing review comments
1 parent 768ccc3 commit a3d386f

5 files changed

Lines changed: 25 additions & 49 deletions

File tree

python/sparkdl/image/imageIO.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
4141
_OcvType = namedtuple("OcvType",["name","ord","nChannels","dtype"])
4242

43-
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
43+
4444
_supportedOcvTypes = (
4545
_OcvType(name="CV_8UC1", ord=0, nChannels=1, dtype="uint8" ),
4646
_OcvType(name="CV_32FC1", ord=5, nChannels=1, dtype="float32"),
@@ -50,18 +50,19 @@
5050
_OcvType(name="CV_32FC4", ord=29, nChannels=4, dtype="float32"),
5151
)
5252

53-
__ocvTypesByName = {m.name:m for m in _supportedOcvTypes}
54-
__ocvTypesByOrdinal = {m.ord:m for m in _supportedOcvTypes}
53+
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
54+
_ocvTypesByName = {m.name:m for m in _supportedOcvTypes}
55+
_ocvTypesByOrdinal = {m.ord:m for m in _supportedOcvTypes}
5556

5657
def imageTypeByOrdinal(ord):
57-
if not ord in __ocvTypesByOrdinal:
58+
if not ord in _ocvTypesByOrdinal:
5859
raise KeyError("unsupported image type with ordinal %d, supported OpenCV types = %s" % (ord,str(_supportedOcvTypes)))
59-
return __ocvTypesByOrdinal[ord]
60+
return _ocvTypesByOrdinal[ord]
6061

6162
def imageTypeByName(name):
62-
if not name in __ocvTypesByName:
63+
if not name in _ocvTypesByName:
6364
raise KeyError("unsupported image type with name '%s', supported supported OpenCV types = %s" % (name,str(_supportedOcvTypes)))
64-
return __ocvTypesByName[name]
65+
return _ocvTypesByName[name]
6566

6667
def imageArrayToStruct(imgArray,origin=""):
6768
"""

python/tests/image/test_imageIO.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -111,21 +111,6 @@ def _test(array):
111111
_test(np.random.randint(0, 256, (10, 11, nChannels), 'uint8'))
112112
_test(np.random.random_sample((10,11,nChannels)).astype('float32'))
113113

114-
def test_image_round_trip(self):
115-
# Test round trip: array -> png -> sparkImg -> array
116-
binarySchema = StructType([StructField("data", BinaryType(), False)])
117-
df = self.session.createDataFrame([[bytearray(pngData)]], binarySchema)
118-
119-
# Convert to images
120-
decImg = udf(lambda x:imageIO.imageArrayToStruct(imageIO.PIL_decode(x)), ImageSchema.imageSchema['image'].dataType)
121-
imageDF = df.select(decImg("data").alias("image"))
122-
row = imageDF.first()
123-
# array comes out of PIL and is in RGB order
124-
testArray = imageIO.imageStructToArray(row.image)
125-
self.assertEqual(testArray.shape, array.shape)
126-
self.assertEqual(testArray.dtype, array.dtype)
127-
self.assertTrue(np.all(array == testArray))
128-
129114
# read images now part of spark, no need to test it here
130115
def test_readImages(self):
131116
# Test that reading

python/tests/transformers/image_utils.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,6 @@ def compareClassSets(self, preds1, preds2):
9292
self.assertEqual(set([v[1] for v in v1]), set([v[1] for v in preds2[k]]))
9393

9494

95-
def getSampleImageList():
96-
imageFiles = glob(os.path.join(_getSampleJPEGDir(), "*"))
97-
images = []
98-
for f in imageFiles:
99-
try:
100-
img = PIL.Image.open(f)
101-
except IOError:
102-
warn("Could not read file in image directory.")
103-
images.append(None)
104-
else:
105-
images.append(img)
106-
return imageFiles, images
107-
10895

10996
def executeKerasInceptionV3(image_df, uri_col="filePath"):
11097
"""

python/tests/transformers/named_image_test.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
from sparkdl.image.image import ImageSchema
3131

3232
from ..tests import SparkDLTestCase
33-
from .image_utils import getSampleImageDF, getSampleImageList
34-
33+
from .image_utils import getSampleImageDF
3534

3635

3736
class KerasApplicationModelTestCase(SparkDLTestCase):
@@ -67,22 +66,29 @@ class NamedImageTransformerBaseTestCase(SparkDLTestCase):
6766
numPartitionsOverride = None
6867

6968

69+
@classmethod
70+
def getSampleImageList(cls):
71+
imageFiles = glob(os.path.join(_getSampleJPEGDir(), "*"))
72+
images = []
73+
for f in imageFiles:
74+
try:
75+
img = PIL.Image.open(f)
76+
shape = cls.appModel.inputShape()
77+
images.append(imageIO._reverseChannels(np.array(img.resize(shape))))
78+
except IOError:
79+
warn("Could not read file in image directory.")
80+
images.append(None)
81+
return imageFiles, np.array(images)
82+
83+
7084
@classmethod
7185
def setUpClass(cls):
7286
super(NamedImageTransformerBaseTestCase, cls).setUpClass()
73-
7487
cls.appModel = keras_apps.getKerasApplicationModel(cls.name)
75-
shape = cls.appModel.inputShape()
76-
77-
imgFiles, images = getSampleImageList()
78-
imageArray = np.empty((len(images), shape[0], shape[1], 3), 'uint8')
79-
for i, img in enumerate(images):
80-
assert img is not None and img.mode == "RGB"
81-
imageArray[i] = imageIO._reverseChannels(np.array(img.resize(shape)))
88+
imgFiles, imageArray = getSampleImageList()
8289
cls.imageArray = imageArray
8390
cls.imgFiles = imgFiles
8491
cls.fileOrder = {imgFiles[i].split('/')[-1]:i for i in range(len(imgFiles))}
85-
8692
# Predict the class probabilities for the images in our test library using keras API
8793
# and cache for use by multiple tests.
8894
preppedImage = cls.appModel._testPreprocess(imageArray.astype('float32'))

python/tests/transformers/tf_image_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ class TFImageTransformerExamplesTest(SparkDLTestCase, ImageNetOutputComparisonTe
3737
# Test loading & pre-processing as an example of a simple graph
3838
# NOTE: resizing here/tensorflow and in keras workflow are different, so the
3939
# test would fail with resizing added in.
40-
4140
def _loadImageViaKeras(self, raw_uri):
4241
uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri
4342
image = img_to_array(load_img(uri))
4443
image = np.expand_dims(image, axis=0)
4544
return preprocess_input(image)
4645

47-
# TODO: I believe this is already tested in named_image_test, should we remove it?
4846
def test_load_image_vs_keras(self):
4947
g = tf.Graph()
5048
with g.as_default():
@@ -70,7 +68,6 @@ def test_load_image_vs_keras(self):
7068
self.assertTrue( (processed == keras_processed).all() )
7169

7270

73-
# TODO: I believe this is already tested in named_image_test, should we remove it?
7471
def test_load_image_vs_keras_RGB(self):
7572
g = tf.Graph()
7673
with g.as_default():

0 commit comments

Comments
 (0)