|
30 | 30 | from sparkdl.image.image import ImageSchema |
31 | 31 |
|
32 | 32 | from ..tests import SparkDLTestCase |
33 | | -from .image_utils import getSampleImageDF, getSampleImageList |
34 | | - |
| 33 | +from .image_utils import getSampleImageDF |
35 | 34 |
|
36 | 35 |
|
37 | 36 | class KerasApplicationModelTestCase(SparkDLTestCase): |
@@ -67,22 +66,29 @@ class NamedImageTransformerBaseTestCase(SparkDLTestCase): |
67 | 66 | numPartitionsOverride = None |
68 | 67 |
|
69 | 68 |
|
| 69 | + @classmethod |
| 70 | + def getSampleImageList(cls): |
| 71 | + imageFiles = glob(os.path.join(_getSampleJPEGDir(), "*")) |
| 72 | + images = [] |
| 73 | + for f in imageFiles: |
| 74 | + try: |
| 75 | + img = PIL.Image.open(f) |
| 76 | + shape = cls.appModel.inputShape() |
| 77 | + images.append(imageIO._reverseChannels(np.array(img.resize(shape)))) |
| 78 | + except IOError: |
| 79 | + warn("Could not read file in image directory.") |
| 80 | + images.append(None) |
| 81 | + return imageFiles, np.array(images) |
| 82 | + |
| 83 | + |
70 | 84 | @classmethod |
71 | 85 | def setUpClass(cls): |
72 | 86 | super(NamedImageTransformerBaseTestCase, cls).setUpClass() |
73 | | - |
74 | 87 | cls.appModel = keras_apps.getKerasApplicationModel(cls.name) |
75 | | - shape = cls.appModel.inputShape() |
76 | | - |
77 | | - imgFiles, images = getSampleImageList() |
78 | | - imageArray = np.empty((len(images), shape[0], shape[1], 3), 'uint8') |
79 | | - for i, img in enumerate(images): |
80 | | - assert img is not None and img.mode == "RGB" |
81 | | - imageArray[i] = imageIO._reverseChannels(np.array(img.resize(shape))) |
| 88 | + imgFiles, imageArray = getSampleImageList() |
82 | 89 | cls.imageArray = imageArray |
83 | 90 | cls.imgFiles = imgFiles |
84 | 91 | cls.fileOrder = {imgFiles[i].split('/')[-1]:i for i in range(len(imgFiles))} |
85 | | - |
86 | 92 | # Predict the class probabilities for the images in our test library using keras API |
87 | 93 | # and cache for use by multiple tests. |
88 | 94 | preppedImage = cls.appModel._testPreprocess(imageArray.astype('float32')) |
|
0 commit comments