diff --git a/README.md b/README.md index 13923fe5..3f57f748 100644 --- a/README.md +++ b/README.md @@ -80,14 +80,23 @@ To try running the examples below, check out the Databricks notebook [DeepLearni ### Working with images in Spark -The first step to applying deep learning on images is the ability to load the images. Deep Learning Pipelines includes utility functions that can load millions of images into a Spark DataFrame and decode them automatically in a distributed fashion, allowing manipulation at scale. +The first step to applying deep learning on images is the ability to load the images. Spark and Deep Learning Pipelines include utility functions that can load millions of images into a Spark DataFrame and decode them automatically in a distributed fashion, allowing manipulation at scale. + +Using Spark's ImageSchema: + +```python +from sparkdl.image.image import ImageSchema +image_df = ImageSchema.readImages("/data/myimages") +``` + +or, if a custom image library is needed: ```python -from sparkdl import readImages -image_df = readImages("/data/myimages") +from sparkdl.image import imageIO as imageIO +image_df = imageIO.readImagesWithCustomFn("/data/myimages", decode_f=imageIO.PIL_decode) ``` -The resulting DataFrame contains a string column named "filePath" containing the path to each image file, and a image struct ("`SpImage`") column named "image" containing the decoded image data. +The resulting DataFrame has the schema `ImageSchema.imageSchema`: a single struct column named "image" that holds the decoded image (not a string column). 
```python image_df.show() @@ -109,7 +118,7 @@ featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features", modelNa lr = LogisticRegression(maxIter=20, regParam=0.05, elasticNetParam=0.3, labelCol="label") p = Pipeline(stages=[featurizer, lr]) -model = p.fit(train_images_df) # train_images_df is a dataset of images (SpImage) and labels +model = p.fit(train_images_df) # train_images_df is a dataset of images and labels # Inspect training error df = model.transform(train_images_df.limit(10)).select("image", "probability", "uri", "label") @@ -127,11 +136,13 @@ Spark DataFrames are a natural construct for applying deep learning models to a There are many well-known deep learning models for images. If the task at hand is very similar to what the models provide (e.g. object recognition with ImageNet classes), or for pure exploration, one can use the Transformer `DeepImagePredictor` by simply specifying the model name. ```python - from sparkdl import readImages, DeepImagePredictor + from sparkdl.image.image import ImageSchema + + from sparkdl import DeepImagePredictor predictor = DeepImagePredictor(inputCol="image", outputCol="predicted_labels", modelName="InceptionV3", decodePredictions=True, topK=10) - image_df = readImages("/data/myimages") + image_df = ImageSchema.readImages("/data/myimages") predictions_df = predictor.transform(image_df) ``` @@ -140,7 +151,8 @@ Spark DataFrames are a natural construct for applying deep learning models to a Deep Learning Pipelines provides a Transformer that will apply the given TensorFlow Graph to a DataFrame containing a column of images (e.g. loaded using the utilities described in the previous section). Here is a very simple example of how a TensorFlow Graph can be used with the Transformer. In practice, the TensorFlow Graph will likely be restored from files before calling `TFImageTransformer`. 
```python - from sparkdl import readImages, TFImageTransformer + from sparkdl.image.image import ImageSchema + from sparkdl import TFImageTransformer import sparkdl.graph.utils as tfx from sparkdl.transformers import utils import tensorflow as tf @@ -155,7 +167,7 @@ Spark DataFrames are a natural construct for applying deep learning models to a transformer = TFImageTransformer(inputCol="image", outputCol="predictions", graph=frozen_graph, inputTensor=image_arr, outputTensor=resized_images, outputMode="image") - image_df = readImages("/data/myimages") + image_df = ImageSchema.readImages("/data/myimages") processed_image_df = transformer.transform(image_df) ``` diff --git a/build.sbt b/build.sbt index ce51837f..ba53e386 100644 --- a/build.sbt +++ b/build.sbt @@ -35,7 +35,7 @@ sparkComponents ++= Seq("mllib-local", "mllib", "sql") // add any Spark Package dependencies using spDependencies. // e.g. spDependencies += "databricks/spark-avro:0.1" spDependencies += s"databricks/tensorframes:0.2.9-s_${scalaMajorVersion}" -spDependencies += "Microsoft/spark-images:0.1" + // These versions are ancient, but they cross-compile around scala 2.10 and 2.11. // Update them when dropping support for scala 2.10 diff --git a/project/plugins.sbt b/project/plugins.sbt index 34b7486e..e5cd848d 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,8 +1,5 @@ // You may use this file to add plugin dependencies for sbt. resolvers += "Spark Packages repo" at "https://dl.bintray.com/spark-packages/maven/" - addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.5") - // scalacOptions in (Compile,doc) := Seq("-groups", "-implicits") - addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0") diff --git a/python/sparkdl/__init__.py b/python/sparkdl/__init__.py index c05b9dfc..228ce6c5 100644 --- a/python/sparkdl/__init__.py +++ b/python/sparkdl/__init__.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - from .graph.input import TFInputGraph -from .image.imageIO import imageSchema, imageType, readImages from .transformers.keras_image import KerasImageFileTransformer from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer from .transformers.tf_image import TFImageTransformer diff --git a/python/sparkdl/estimators/keras_image_file_estimator.py b/python/sparkdl/estimators/keras_image_file_estimator.py index 6dcb1ef1..1d67ed6b 100644 --- a/python/sparkdl/estimators/keras_image_file_estimator.py +++ b/python/sparkdl/estimators/keras_image_file_estimator.py @@ -36,6 +36,7 @@ logger = logging.getLogger('sparkdl') + class KerasImageFileEstimator(Estimator, HasInputCol, HasInputImageNodeName, HasOutputCol, HasOutputNodeName, HasLabelCol, HasKerasModel, HasKerasOptimizer, HasKerasLoss, diff --git a/python/sparkdl/graph/builder.py b/python/sparkdl/graph/builder.py index a7d7122f..c8f3ce4b 100644 --- a/python/sparkdl/graph/builder.py +++ b/python/sparkdl/graph/builder.py @@ -27,6 +27,7 @@ logger = logging.getLogger('sparkdl') + class IsolatedSession(object): """ Provide an isolated session to work with mixed Keras and TensorFlow @@ -43,6 +44,7 @@ class IsolatedSession(object): In this case, all Keras models loaded in this session will be accessible as a subgraph of of `graph` """ + def __init__(self, graph=None, using_keras=False): self.graph = graph or tf.Graph() self.sess = tf.Session(graph=self.graph) @@ -166,7 +168,7 @@ def _fromKerasModelFile(cls, file_path): 'Keras model must be specified as HDF5 file' with IsolatedSession(using_keras=True) as issn: - K.set_learning_phase(0) # Testing phase + K.set_learning_phase(0) # Testing phase model = load_model(file_path) gfn = issn.asGraphFunction(model.inputs, model.outputs) @@ -223,7 +225,8 @@ def fromList(cls, functions): # We currently only support single input/output for intermediary stages # The functions could still take multi-dimensional tensor, but only one if len(gfn_out.input_names) != 1: - 
raise NotImplementedError("Only support single input/output for intermediary layers") + raise NotImplementedError( + "Only support single input/output for intermediary layers") # Acquire initial placeholders' properties # We want the input names of the merged function are not under scoped diff --git a/python/sparkdl/graph/input.py b/python/sparkdl/graph/input.py index 67ab1119..00df2b6f 100644 --- a/python/sparkdl/graph/input.py +++ b/python/sparkdl/graph/input.py @@ -23,6 +23,7 @@ # pylint: disable=invalid-name,wrong-spelling-in-comment,wrong-spelling-in-docstring + class TFInputGraph(object): """ An opaque object containing TensorFlow graph. @@ -84,7 +85,6 @@ class TFInputGraph(object): Please see the example above. """ - def __init__(self, graph_def, input_tensor_name_from_signature, output_tensor_name_from_signature): self.graph_def = graph_def @@ -281,6 +281,7 @@ def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_n return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, fetch_names=fetch_names) + def _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key, feed_names, fetch_names): """ Construct a TFInputGraph from a SavedModel. diff --git a/python/sparkdl/graph/pieces.py b/python/sparkdl/graph/pieces.py index 4630835a..f282fb27 100644 --- a/python/sparkdl/graph/pieces.py +++ b/python/sparkdl/graph/pieces.py @@ -18,7 +18,7 @@ import tensorflow as tf from sparkdl.graph.builder import IsolatedSession -from sparkdl.image.imageIO import SparkMode +from sparkdl.image import imageIO logger = logging.getLogger('sparkdl') @@ -29,7 +29,8 @@ Deserializing ProtocolBuffer bytes is in general faster than directly loading Keras models. """ -def buildSpImageConverter(img_dtype): + +def buildSpImageConverter(channelOrder, img_dtype): """ Convert a imageIO byte encoded image into a image tensor suitable as input to ConvNets The name of the input must be a subset of those specified in `image.imageIO.imageSchema`. 
@@ -48,23 +49,25 @@ def buildSpImageConverter(img_dtype): # This is the default behavior of Python Image Library shape = tf.reshape(tf.stack([height, width, num_channels], axis=0), shape=(3,), name='shape') - if img_dtype == SparkMode.RGB: + if img_dtype == 'uint8': image_uint8 = tf.decode_raw(image_buffer, tf.uint8, name="decode_raw") image_float = tf.to_float(image_uint8) - else: - assert img_dtype == SparkMode.RGB_FLOAT32, \ - "Unsupported dtype for image: {}".format(img_dtype) + elif img_dtype == 'float32': image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw") - + else: + raise ValueError( + 'unsupported image data type "%s", currently only know how to handle uint8 and float32' % img_dtype) image_reshaped = tf.reshape(image_float, shape, name="reshaped") + image_reshaped = imageIO.fixColorChannelOrdering(channelOrder, image_reshaped) image_input = tf.expand_dims(image_reshaped, 0, name="image_input") gfn = issn.asGraphFunction([height, width, image_buffer, num_channels], [image_input]) return gfn + def buildFlattener(): - """ - Build a flattening layer to remove the extra leading tensor dimension. + """ + Build a flattening layer to remove the extra leading tensor dimension. e.g. a tensor of shape [1, W, H, C] will have a shape [W, H, C] after applying this. 
""" with IsolatedSession() as issn: diff --git a/python/sparkdl/graph/tensorframes_udf.py b/python/sparkdl/graph/tensorframes_udf.py index aa1531b4..165917a8 100644 --- a/python/sparkdl/graph/tensorframes_udf.py +++ b/python/sparkdl/graph/tensorframes_udf.py @@ -23,6 +23,7 @@ logger = logging.getLogger('sparkdl') + def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=False, register=True): """ Create a Spark SQL UserDefinedFunction from a given TensorFlow Graph diff --git a/python/sparkdl/graph/utils.py b/python/sparkdl/graph/utils.py index cfac1580..020fb655 100644 --- a/python/sparkdl/graph/utils.py +++ b/python/sparkdl/graph/utils.py @@ -31,6 +31,7 @@ one of the four target variants. """ + def validated_graph(graph): """ Check if the input is a valid :py:class:`tf.Graph` and return it. @@ -41,6 +42,7 @@ def validated_graph(graph): assert isinstance(graph, tf.Graph), 'must provide tf.Graph, but get {}'.format(type(graph)) return graph + def get_shape(tfobj_or_name, graph): """ Return the shape of the tensor as a list @@ -52,6 +54,7 @@ def get_shape(tfobj_or_name, graph): _shape = get_tensor(tfobj_or_name, graph).get_shape().as_list() return [-1 if x is None else x for x in _shape] + def get_op(tfobj_or_name, graph): """ Get a :py:class:`tf.Operation` object. 
@@ -76,6 +79,7 @@ def get_op(tfobj_or_name, graph): assert isinstance(op, tf.Operation), err_msg.format(_op_name, type(op), op) return op + def get_tensor(tfobj_or_name, graph): """ Get a :py:class:`tf.Tensor` object @@ -100,6 +104,7 @@ def get_tensor(tfobj_or_name, graph): assert isinstance(tnsr, tf.Tensor), err_msg.format(_tensor_name, type(tnsr), tnsr) return tnsr + def tensor_name(tfobj_or_name, graph=None): """ Derive the :py:class:`tf.Tensor` name from a :py:class:`tf.Operation` or :py:class:`tf.Tensor` @@ -130,6 +135,7 @@ def tensor_name(tfobj_or_name, graph=None): else: raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name))) + def op_name(tfobj_or_name, graph=None): """ Derive the :py:class:`tf.Operation` name from a :py:class:`tf.Operation` or @@ -158,9 +164,11 @@ def op_name(tfobj_or_name, graph=None): else: raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name))) + def add_scope_to_name(scope, name): """ Prepends the provided scope to the passed-in op or tensor name. """ - return "%s/%s"%(scope, name) + return "%s/%s" % (scope, name) + def validated_output(tfobj_or_name, graph): """ @@ -172,6 +180,7 @@ def validated_output(tfobj_or_name, graph): graph = validated_graph(graph) return op_name(tfobj_or_name, graph) + def validated_input(tfobj_or_name, graph): """ Validate and return the input names useable GraphFunction @@ -186,6 +195,7 @@ def validated_input(tfobj_or_name, graph): ('input must be Placeholder, but get', op.type) return name + def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False): """ Create a static view of the graph by diff --git a/python/sparkdl/image/image.py b/python/sparkdl/image/image.py new file mode 100644 index 00000000..2569abdb --- /dev/null +++ b/python/sparkdl/image/image.py @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# NOTE: This file is copied from Spark2.3 in order to be able to use this in already released spark versions. +# TODO: remove this when Spark 2.3 is out! + +""" +.. attribute:: ImageSchema + + An attribute of this module that contains the instance of :class:`_ImageSchema`. + +.. autoclass:: _ImageSchema + :members: +""" + +import numpy as np +from pyspark import SparkContext +from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string +from pyspark.sql import DataFrame, SparkSession + + +class _ImageSchema(object): + """ + Internal class for `pyspark.ml.image.ImageSchema` attribute. Meant to be private and + not to be instantized. Use `pyspark.ml.image.ImageSchema` attribute to access the + APIs of this class. + """ + + def __init__(self): + self._imageSchema = None + self._ocvTypes = None + self._imageFields = None + self._undefinedImageType = None + + @property + def imageSchema(self): + """ + Returns the image schema. + + :return: a :class:`StructType` with a single column of images + named "image" (nullable). + + .. 
versionadded:: 2.3.0 + """ + + if self._imageSchema is None: + ctx = SparkContext._active_spark_context + jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema() + self._imageSchema = _parse_datatype_json_string(jschema.json()) + return self._imageSchema + + @property + def ocvTypes(self): + """ + Returns the OpenCV type mapping supported. + + :return: a dictionary containing the OpenCV type mapping supported. + + .. versionadded:: 2.3.0 + """ + + if self._ocvTypes is None: + ctx = SparkContext._active_spark_context + self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) + return self._ocvTypes + + @property + def imageFields(self): + """ + Returns field names of image columns. + + :return: a list of field names. + + .. versionadded:: 2.3.0 + """ + + if self._imageFields is None: + ctx = SparkContext._active_spark_context + self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields()) + return self._imageFields + + @property + def undefinedImageType(self): + """ + Returns the name of undefined image type for the invalid image. + + .. versionadded:: 2.3.0 + """ + + if self._undefinedImageType is None: + ctx = SparkContext._active_spark_context + self._undefinedImageType = \ + ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() + return self._undefinedImageType + + def toNDArray(self, image): + """ + Converts an image to an array with metadata. + + :param `Row` image: A row that contains the image to be converted. It should + have the attributes specified in `ImageSchema.imageSchema`. + :return: a `numpy.ndarray` that is an image. + + .. versionadded:: 2.3.0 + """ + + if not isinstance(image, Row): + raise TypeError( + "image argument should be pyspark.sql.types.Row; however, " + "it got [%s]." % type(image)) + + if any(not hasattr(image, f) for f in self.imageFields): + raise ValueError( + "image argument should have attributes specified in " + "ImageSchema.imageSchema [%s]." 
% ", ".join(self.imageFields)) + + height = image.height + width = image.width + nChannels = image.nChannels + return np.ndarray( + shape=(height, width, nChannels), + dtype=np.uint8, + buffer=image.data, + strides=(width * nChannels, nChannels, 1)) + + def toImage(self, array, origin=""): + """ + Converts an array with metadata to a two-dimensional image. + + :param `numpy.ndarray` array: The array to convert to image. + :param str origin: Path to the image, optional. + :return: a :class:`Row` that is a two dimensional image. + + .. versionadded:: 2.3.0 + """ + + if not isinstance(array, np.ndarray): + raise TypeError( + "array argument should be numpy.ndarray; however, it got [%s]." % type(array)) + + if array.ndim != 3: + raise ValueError("Invalid array shape") + + height, width, nChannels = array.shape + ocvTypes = ImageSchema.ocvTypes + if nChannels == 1: + mode = ocvTypes["CV_8UC1"] + elif nChannels == 3: + mode = ocvTypes["CV_8UC3"] + elif nChannels == 4: + mode = ocvTypes["CV_8UC4"] + else: + raise ValueError("Invalid number of channels") + + # Running `bytearray(numpy.array([1]))` fails in specific Python versions + # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3. + # Here, it avoids it by converting it to bytes. + data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) + + # Creating new Row with _create_row(), because Row(name = value, ... ) + # orders fields by name, which conflicts with expected schema order + # when the new DataFrame is created by UDF + return _create_row(self.imageFields, + [origin, height, width, nChannels, mode, data]) + + def readImages(self, path, recursive=False, numPartitions=-1, + dropImageFailures=False, sampleRatio=1.0, seed=0): + """ + Reads the directory of images from the local or remote source. + + .. note:: If multiple jobs are run in parallel with different sampleRatio or recursive flag, + there may be a race condition where one job overwrites the hadoop configs of another. + + .. 
note:: If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + potentially non-deterministic. + + :param str path: Path to the image directory. + :param bool recursive: Recursive search flag. + :param int numPartitions: Number of DataFrame partitions. + :param bool dropImageFailures: Drop the files that are not valid images. + :param float sampleRatio: Fraction of the images loaded. + :param int seed: Random number seed. + :return: a :class:`DataFrame` with a single column of "images", + see ImageSchema for details. + + >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) + >>> df.count() + 4 + + .. versionadded:: 2.3.0 + """ + + ctx = SparkContext._active_spark_context + spark = SparkSession(ctx) + image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema + jsession = spark._jsparkSession + jresult = image_schema.readImages(path, jsession, recursive, numPartitions, + dropImageFailures, float(sampleRatio), seed) + return DataFrame(jresult, spark._wrapped) + + +ImageSchema = _ImageSchema() + + +# Monkey patch to disallow instantization of this class. 
+def _disallow_instance(_): + raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") + + +_ImageSchema.__init__ = _disallow_instance diff --git a/python/sparkdl/image/imageIO.py b/python/sparkdl/image/imageIO.py index 03c101cc..a2a60145 100644 --- a/python/sparkdl/image/imageIO.py +++ b/python/sparkdl/image/imageIO.py @@ -15,7 +15,6 @@ from io import BytesIO from collections import namedtuple -from warnings import warn # 3rd party import numpy as np @@ -24,91 +23,69 @@ # pyspark from pyspark import Row from pyspark import SparkContext -from pyspark.sql.types import (BinaryType, IntegerType, StringType, StructField, StructType) +from sparkdl.image.image import ImageSchema from pyspark.sql.functions import udf +from pyspark.sql.types import ( + BinaryType, IntegerType, StringType, StructField, StructType) -imageSchema = StructType([StructField("mode", StringType(), False), - StructField("height", IntegerType(), False), - StructField("width", IntegerType(), False), - StructField("nChannels", IntegerType(), False), - StructField("data", BinaryType(), False)]) - - -# ImageType class for holding metadata about images stored in DataFrames. +# ImageType represents supported OpenCV types # fields: +# name - OpenCvMode +# ord - Ordinal of the corresponding OpenCV mode (stored in mode field of ImageSchema). # nChannels - number of channels in the image -# dtype - data type of the image's "data" Column, sorted as a numpy compatible string. -# channelContent - info about the contents of each channel currently only "I" (intensity) and -# "RGB" are supported for 1 and 3 channel data respectively. -# pilMode - The mode that should be used to convert to a PIL image. -# sparkMode - Unique identifier string used in spark image representation. 
-ImageType = namedtuple("ImageType", ["nChannels", - "dtype", - "channelContent", - "pilMode", - "sparkMode", - ]) -class SparkMode(object): - RGB = "RGB" - FLOAT32 = "float32" - RGB_FLOAT32 = "RGB-float32" - -supportedImageTypes = [ - ImageType(3, "uint8", "RGB", "RGB", SparkMode.RGB), - ImageType(1, "float32", "I", "F", SparkMode.FLOAT32), - ImageType(3, "float32", "RGB", None, SparkMode.RGB_FLOAT32), -] -pilModeLookup = {t.pilMode: t for t in supportedImageTypes - if t.pilMode is not None} -sparkModeLookup = {t.sparkMode: t for t in supportedImageTypes} - - -def imageArrayToStruct(imgArray, sparkMode=None): - """ - Create a row representation of an image from an image array and (optional) imageType. +# dtype - data type of the image's array, sorted as a numpy compatible string. +# +# NOTE: likely to be migrated to Spark ImageSchema code in the near future. +_OcvType = namedtuple("OcvType", ["name", "ord", "nChannels", "dtype"]) - to_image_udf = udf(arrayToImageRow, imageSchema) - df.withColumn("output_img", to_image_udf(df["np_arr_col"]) - :param imgArray: ndarray, image data. - :param sparkMode: spark mode, type information for the image, will be inferred from array if - the mode is not provide. See SparkMode for valid modes. - :return: Row, image as a DataFrame Row. - """ - # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists. 
- if len(imgArray.shape) == 4: - if imgArray.shape[0] != 1: - raise ValueError("The first dimension of a 4-d image array is expected to be 1.") - imgArray = imgArray.reshape(imgArray.shape[1:]) +_supportedOcvTypes = ( + _OcvType(name="CV_8UC1", ord=0, nChannels=1, dtype="uint8"), + _OcvType(name="CV_32FC1", ord=5, nChannels=1, dtype="float32"), + _OcvType(name="CV_8UC3", ord=16, nChannels=3, dtype="uint8"), + _OcvType(name="CV_32FC3", ord=21, nChannels=3, dtype="float32"), + _OcvType(name="CV_8UC4", ord=24, nChannels=4, dtype="uint8"), + _OcvType(name="CV_32FC4", ord=29, nChannels=4, dtype="float32"), +) - if sparkMode is None: - sparkMode = _arrayToSparkMode(imgArray) - imageType = sparkModeLookup[sparkMode] +# NOTE: likely to be migrated to Spark ImageSchema code in the near future. +_ocvTypesByName = {m.name: m for m in _supportedOcvTypes} +_ocvTypesByOrdinal = {m.ord: m for m in _supportedOcvTypes} - height, width, nChannels = imgArray.shape - if imageType.nChannels != nChannels: - msg = "Image of type {} should have {} channels, but array has {} channels." - raise ValueError(msg.format(sparkMode, imageType.nChannels, nChannels)) - # Convert the array to match the image type. - if not np.can_cast(imgArray, imageType.dtype, 'same_kind'): - msg = "Array of type {} cannot safely be cast to image type {}." 
- raise ValueError(msg.format(imgArray.dtype, imageType.dtype)) - imgArray = np.array(imgArray, dtype=imageType.dtype, copy=False) +def imageTypeByOrdinal(ord): + if not ord in _ocvTypesByOrdinal: + raise KeyError("unsupported image type with ordinal %d, supported OpenCV types = %s" % ( + ord, str(_supportedOcvTypes))) + return _ocvTypesByOrdinal[ord] - data = bytearray(imgArray.tobytes()) - return Row(mode=sparkMode, height=height, width=width, nChannels=nChannels, data=data) +def imageTypeByName(name): + if not name in _ocvTypesByName: + raise KeyError("unsupported image type with name '%s', supported supported OpenCV types = %s" % ( + name, str(_supportedOcvTypes))) + return _ocvTypesByName[name] -def imageType(imageRow): + +def imageArrayToStruct(imgArray, origin=""): """ - Get type information about the image. + Create a row representation of an image from an image array. - :param imageRow: spark image row. - :return: ImageType + :param imgArray: ndarray, image data. + :return: Row, image as a DataFrame Row with schema==ImageSchema. """ - return sparkModeLookup[imageRow.mode] + # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists. + if len(imgArray.shape) == 4: + if imgArray.shape[0] != 1: + raise ValueError( + "The first dimension of a 4-d image array is expected to be 1.") + imgArray = imgArray.reshape(imgArray.shape[1:]) + imageType = _arrayToOcvMode(imgArray) + height, width, nChannels = imgArray.shape + data = bytearray(imgArray.tobytes()) + return Row(origin=origin, mode=imageType.ord, height=height, + width=width, nChannels=nChannels, data=data) def imageStructToArray(imageRow): @@ -118,89 +95,98 @@ def imageStructToArray(imageRow): :param imageRow: Row, must use imageSchema. :return: ndarray, image data. 
""" - imType = imageType(imageRow) + imType = imageTypeByOrdinal(imageRow.mode) shape = (imageRow.height, imageRow.width, imageRow.nChannels) return np.ndarray(shape, imType.dtype, imageRow.data) -def _arrayToSparkMode(arr): - assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format(arr.shape) - num_channels = arr.shape[2] - if num_channels == 1: - if arr.dtype not in [np.float16, np.float32, np.float64]: - raise ValueError("incompatible dtype (%s) for numpy array for float32 mode" % - arr.dtype.string) - return SparkMode.FLOAT32 - elif num_channels != 3: - raise ValueError("number of channels of the input array (%d) is not supported" % - num_channels) - elif arr.dtype == np.uint8: - return SparkMode.RGB - elif arr.dtype in [np.float16, np.float32, np.float64]: - return SparkMode.RGB_FLOAT32 +def imageStructToPIL(imageRow): + """ + Convert the immage from image schema struct to PIL image + + :param imageRow: Row, must have ImageSchema + :return PIL image + """ + imgType = imageTypeByOrdinal(imageRow.mode) + if imgType.dtype != 'uint8': + raise ValueError("Can not convert image of type " + + imgType.dtype + " to PIL, can only deal with 8U format") + ary = imageStructToArray(imageRow) + # PIL expects RGB order, image schema is BGR + # => we need to flip the order unless there is only one channel + if imgType.nChannels != 1: + ary = _reverseChannels(ary) + if imgType.nChannels == 1: + return Image.fromarray(obj=ary, mode='L') + elif imgType.nChannels == 3: + return Image.fromarray(obj=ary, mode='RGB') + elif imgType.nChannels == 4: + return Image.fromarray(obj=ary, mode='RGBA') else: - raise ValueError("did not find a sparkMode for the given array with num_channels = %d " + - "and dtype %s" % (num_channels, arr.dtype.string)) + raise ValueError("don't know how to convert " + + imgType.name + " to PIL") -def _resizeFunction(size): - """ Creates a resize function. - - :param size: tuple, size of new image: (height, width). 
- :return: function: image => image, a function that converts an input image to an image with - of `size`. - """ +def PIL_to_imageStruct(img): + # PIL is RGB based, image schema expects BGR ordering => need to flip the channels + return _reverseChannels(np.asarray(img)) - if len(size) != 2: - raise ValueError("New image size should have for [hight, width] but got {}".format(size)) - def resizeImageAsRow(imgAsRow): - imgAsArray = imageStructToArray(imgAsRow) - imgType = imageType(imgAsRow) - imgAsPil = Image.fromarray(imgAsArray, imgType.pilMode) - imgAsPil = imgAsPil.resize(size[::-1]) - imgAsArray = np.array(imgAsPil) - return imageArrayToStruct(imgAsArray, imgType.sparkMode) +def _arrayToOcvMode(arr): + assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format( + arr.shape) + num_channels = arr.shape[2] + if arr.dtype == "uint8": + name = "CV_8UC%d" % num_channels + elif arr.dtype == "float32": + name = "CV_32FC%d" % num_channels + else: + raise ValueError("Unsupported type '%s'" % arr.dtype) + return imageTypeByName(name) + + +def fixColorChannelOrdering(currentOrder, imgAry): + if currentOrder == 'RGB': + return _reverseChannels(imgAry) + elif currentOrder == 'BGR': + return imgAry + elif currentOrder == 'L': + if len(img.shape) != 1: + raise ValueError( + "channel order suggests only one color channel but got shape " + str(img.shape)) + return imgAry + else: + raise ValueError( + "Unexpected channel order, expected one of L,RGB,BGR but got " + currentChannelOrder) + - return resizeImageAsRow +def _reverseChannels(ary): + return ary[..., ::-1] -def resizeImage(size): +def createResizeImageUDF(size): """ Create a udf for resizing image. - + Example usage: dataFrame.select(resizeImage((height, width))('imageColumn')) - - :param size: tuple, target size of new image in the form (height, width). - :return: udf, a udf for resizing an image column to `size`. 
- """ - return udf(_resizeFunction(size), imageSchema) - -def _decodeImage(imageData): - """ - Decode compressed image data into a DataFrame image row. - - :param imageData: (bytes, bytearray) compressed image data in PIL compatible format. - :return: Row, decoded image. + :param size: tuple, target size of new image in the form (height, width). + :return: udf, a udf for resizing an image column to `size`. """ - try: - img = Image.open(BytesIO(imageData)) - except IOError: - return None + if len(size) != 2: + raise ValueError( + "New image size should have format [height, width] but got {}".format(size)) + sz = (size[1], size[0]) - if img.mode in pilModeLookup: - mode = pilModeLookup[img.mode] - else: - msg = "We don't currently support images with mode: {mode}" - warn(msg.format(mode=img.mode)) - return None - imgArray = np.asarray(img) - image = imageArrayToStruct(imgArray, mode.sparkMode) - return image + def _resizeImageAsRow(imgAsRow): + if (imgAsRow.height, imgAsRow.width) == sz: + return imgAsRow + imgAsPil = imageStructToPIL(imgAsRow).resize(sz) + # PIL is RGB based while image schema is BGR based => we need to flip the channels + imgAsArray = _reverseChannels(np.asarray(imgAsPil)) + return imageArrayToStruct(imgAsArray, origin=imgAsRow.origin) + return udf(_resizeImageAsRow, ImageSchema.imageSchema['image'].dataType) -# Creating a UDF on import can cause SparkContext issues sometimes. 
-# decodeImage = udf(_decodeImage, imageSchema) def filesToDF(sc, path, numPartitions=None): """ @@ -214,24 +200,51 @@ def filesToDF(sc, path, numPartitions=None): numPartitions = numPartitions or sc.defaultParallelism schema = StructType([StructField("filePath", StringType(), False), StructField("fileData", BinaryType(), False)]) - rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions) + rdd = sc.binaryFiles( + path, minPartitions=numPartitions).repartition(numPartitions) rdd = rdd.map(lambda x: (x[0], bytearray(x[1]))) return rdd.toDF(schema) -def readImages(imageDirectory, numPartition=None): +def PIL_decode(raw_bytes): + """ + Decode a raw image bytes using PIL. + :param raw_bytes: + :return: image data as an array in CV_8UC3 format """ - Read a directory of images (or a single image) into a DataFrame. + return PIL_to_imageStruct(Image.open(BytesIO(raw_bytes))) - :param sc: spark context - :param imageDirectory: str, file path. - :param numPartition: int, number or partitions to use for reading files. - :return: DataFrame, with columns: (filepath: str, image: imageSchema). + +def PIL_decode_and_resize(size): + """ + Decode a raw image bytes using PIL and resize it to target dimension, both using PIL. + :param raw_bytes: + :return: image data as an array in CV_8UC3 format """ - return _readImages(imageDirectory, numPartition, SparkContext.getOrCreate()) + def _decode(raw_bytes): + return PIL_to_imageStruct(Image.open(BytesIO(raw_bytes)).resize(size)) + return _decode -def _readImages(imageDirectory, numPartition, sc): - decodeImage = udf(_decodeImage, imageSchema) - imageData = filesToDF(sc, imageDirectory, numPartitions=numPartition) - return imageData.select("filePath", decodeImage("fileData").alias("image")) +def readImagesWithCustomFn(path, decode_f, numPartition=None): + """ + Read a directory of images (or a single image) into a DataFrame using a custom library to decode the images. + + :param path: str, file path. 
+ :param decode_f: function to decode the raw bytes into an array compatible with one of the supported OpenCv modes. + see @imageIO.PIL_decode for an example. + :param numPartition: [optional] int, number or partitions to use for reading files. + :return: DataFrame with schema == ImageSchema.imageSchema. + """ + return _readImagesWithCustomFn(path, decode_f, numPartition, sc=SparkContext.getOrCreate()) + + +def _readImagesWithCustomFn(path, decode_f, numPartition, sc): + def _decode(path, raw_bytes): + try: + return imageArrayToStruct(decode_f(raw_bytes), origin=path) + except BaseException: + return None + decodeImage = udf(_decode, ImageSchema.imageSchema['image'].dataType) + imageData = filesToDF(sc, path, numPartitions=numPartition) + return imageData.select(decodeImage("filePath", "fileData").alias("image")) diff --git a/python/sparkdl/param/converters.py b/python/sparkdl/param/converters.py index 25a2e3a1..dd977005 100644 --- a/python/sparkdl/param/converters.py +++ b/python/sparkdl/param/converters.py @@ -35,6 +35,7 @@ __all__ = ['SparkDLTypeConverters'] + class SparkDLTypeConverters(object): """ .. note:: DeveloperApi @@ -167,6 +168,13 @@ def toKerasOptimizer(value): return value + @staticmethod + def toChannelOrder(value): + if not value in ('L', 'RGB', 'BGR'): + raise ValueError( + "Unsupported channel order. Expected one of ('L', 'RGB', 'BGR') but got '%s'") % value + return value + def _check_is_tensor_name(_maybe_tnsr_name): """ Check if the input is a valid tensor name or raise a `TypeError` otherwise. """ diff --git a/python/sparkdl/param/image_params.py b/python/sparkdl/param/image_params.py index 6ca2ff6d..0e7dcdbc 100644 --- a/python/sparkdl/param/image_params.py +++ b/python/sparkdl/param/image_params.py @@ -19,14 +19,16 @@ private APIs. 
""" +from sparkdl.image.image import ImageSchema from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql.functions import udf - -from sparkdl.image.imageIO import imageArrayToStruct, imageSchema +from sparkdl.image.imageIO import imageArrayToStruct +from sparkdl.image.imageIO import _reverseChannels from sparkdl.param import SparkDLTypeConverters OUTPUT_MODES = ["vector", "image"] + class HasInputImageNodeName(Params): # TODO: docs inputImageNodeName = Param(Params._dummy(), "inputImageNodeName", @@ -39,6 +41,7 @@ def setInputImageNodeName(self, value): def getInputImageNodeName(self): return self.getOrDefault(self.inputImageNodeName) + class CanLoadImage(Params): """ In standard Keras workflow, we use provides an image loading function @@ -71,7 +74,7 @@ def image_loader(uri): imageLoader = Param(Params._dummy(), "imageLoader", "Function containing the logic for loading and pre-processing images. " + "The function should take in a URI string and return a 4-d numpy.array " + - "with shape (batch_size (1), height, width, num_channels).") + "with shape (batch_size (1), height, width, num_channels). Expected to return result with color channels in RGB order.") def setImageLoader(self, value): return self._set(imageLoader=value) @@ -89,15 +92,14 @@ def loadImagesInternal(self, dataframe, inputCol): # plan 1: udf(loader() + convert from np.array to imageSchema) -> call TFImageTransformer # plan 2: udf(loader()) ... we don't support np.array as a dataframe column type... 
loader = self.getImageLoader() - # Load from external resources can fail, so we should allow None to be returned + def load_image_uri_impl(uri): try: - return imageArrayToStruct(loader(uri)) - except: # pylint: disable=bare-except + return imageArrayToStruct(_reverseChannels(loader(uri))) + except BaseException: # pylint: disable=bare-except return None - - load_udf = udf(load_image_uri_impl, imageSchema) + load_udf = udf(load_image_uri_impl, ImageSchema.imageSchema['image'].dataType) return dataframe.withColumn(self._loadedImageCol(), load_udf(dataframe[inputCol])) diff --git a/python/sparkdl/param/shared_params.py b/python/sparkdl/param/shared_params.py index ad3aa2aa..64f6c106 100644 --- a/python/sparkdl/param/shared_params.py +++ b/python/sparkdl/param/shared_params.py @@ -178,7 +178,7 @@ def _loadTFGraph(self, sess, graph): keras_backend = K.backend() assert keras_backend == "tensorflow", \ "Only tensorflow-backed Keras models are supported, tried to load Keras model " \ - "with backend %s."%(keras_backend) + "with backend %s." % (keras_backend) with graph.as_default(): K.set_learning_phase(0) # Inference phase model = load_model(self.getModelFile()) diff --git a/python/sparkdl/transformers/keras_applications.py b/python/sparkdl/transformers/keras_applications.py index 87f0eeab..d88a1d5b 100644 --- a/python/sparkdl/transformers/keras_applications.py +++ b/python/sparkdl/transformers/keras_applications.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright 2017 Databricks, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,10 +15,10 @@ # # -# Models marked below as provided by Keras are provided subject to the -# below copyright and licenses (and any additional copyrights and +# Models marked below as provided by Keras are provided subject to the +# below copyright and licenses (and any additional copyrights and # licenses specified). 
-# +# # COPYRIGHT # # All contributions by François Chollet: @@ -71,12 +72,15 @@ import tensorflow as tf from sparkdl.transformers.utils import (imageInputPlaceholder, InceptionV3Constants) +from sparkdl.image.imageIO import _reverseChannels """ Essentially a factory function for getting the correct KerasApplicationModel class for the network name. """ + + def getKerasApplicationModel(name): try: return KERAS_APPLICATION_MODELS[name]() @@ -108,9 +112,9 @@ def preprocess(self, inputImage): @abstractmethod def model(self, preprocessed, featurize): """ - Models marked as *provided by Keras* are provided subject to the MIT + Models marked as *provided by Keras* are provided subject to the MIT license located at https://github.com/fchollet/keras/blob/master/LICENSE - and subject to any additional copyrights and licenses specified in the + and subject to any additional copyrights and licenses specified in the code or documentation. """ pass @@ -135,7 +139,8 @@ def _testKerasModel(self, include_top): class InceptionV3Model(KerasApplicationModel): def preprocess(self, inputImage): - return inception_v3.preprocess_input(inputImage) + # Keras expects RGB order + return inception_v3.preprocess_input(_reverseChannels(inputImage)) def model(self, preprocessed, featurize): # Model provided by Keras. All cotributions by Keras are provided subject to the @@ -167,9 +172,11 @@ def inputShape(self): def _testKerasModel(self, include_top): return inception_v3.InceptionV3(weights="imagenet", include_top=include_top) + class XceptionModel(KerasApplicationModel): def preprocess(self, inputImage): - return xception.preprocess_input(inputImage) + # Keras expects RGB order + return xception.preprocess_input(_reverseChannels(inputImage)) def model(self, preprocessed, featurize): # Model provided by Keras. 
All cotributions by Keras are provided subject to the @@ -183,6 +190,7 @@ def inputShape(self): def _testKerasModel(self, include_top): return xception.Xception(weights="imagenet", include_top=include_top) + class ResNet50Model(KerasApplicationModel): def preprocess(self, inputImage): return _imagenet_preprocess_input(inputImage, self.inputShape()) @@ -222,6 +230,7 @@ def inputShape(self): def _testKerasModel(self, include_top): return resnet50.ResNet50(weights="imagenet", include_top=include_top) + class VGG16Model(KerasApplicationModel): def preprocess(self, inputImage): return _imagenet_preprocess_input(inputImage, self.inputShape()) @@ -232,7 +241,7 @@ def model(self, preprocessed, featurize): # and subject to the below additional copyrights and licenses. # # Copyright 2014 Oxford University - # + # # Licensed under the Creative Commons Attribution License CC BY 4.0 ("License"). # You may obtain a copy of the License at # @@ -247,6 +256,7 @@ def inputShape(self): def _testKerasModel(self, include_top): return vgg16.VGG16(weights="imagenet", include_top=include_top) + class VGG19Model(KerasApplicationModel): def preprocess(self, inputImage): return _imagenet_preprocess_input(inputImage, self.inputShape()) @@ -257,7 +267,7 @@ def model(self, preprocessed, featurize): # and subject to the below additional copyrights and licenses. # # Copyright 2014 Oxford University - # + # # Licensed under the Creative Commons Attribution License CC BY 4.0 ("License"). # You may obtain a copy of the License at # @@ -280,10 +290,9 @@ def _imagenet_preprocess_input(x, input_shape): works okay with tf.Tensor inputs. The following was translated to tf ops from https://github.com/fchollet/keras/blob/fb4a0849cf4dc2965af86510f02ec46abab1a6a4/keras/applications/imagenet_utils.py#L52 It's a possibility to change the implementation in keras to look like the - following, but not doing it for now. + following and modified to work with BGR images (standard in Spark), but not doing it for now. 
""" - # 'RGB'->'BGR' - x = x[..., ::-1] + # assuming 'BGR' # Zero-center by mean pixel mean = np.ones(input_shape + (3,), dtype=np.float32) mean[..., 0] = 103.939 @@ -291,6 +300,7 @@ def _imagenet_preprocess_input(x, input_shape): mean[..., 2] = 123.68 return x - mean + KERAS_APPLICATION_MODELS = { "InceptionV3": InceptionV3Model, "Xception": XceptionModel, diff --git a/python/sparkdl/transformers/keras_image.py b/python/sparkdl/transformers/keras_image.py index 0b2f7e6b..cbfa3fb8 100644 --- a/python/sparkdl/transformers/keras_image.py +++ b/python/sparkdl/transformers/keras_image.py @@ -61,7 +61,7 @@ def _transform(self, dataset): graph, inputTensorName, outputTensorName = self._loadTFGraph(sess=sess, graph=keras_graph) image_df = self.loadImagesInternal(dataset, self.getInputCol()) - transformer = TFImageTransformer(inputCol=self._loadedImageCol(), + transformer = TFImageTransformer(channelOrder='RGB', inputCol=self._loadedImageCol(), outputCol=self.getOutputCol(), graph=graph, inputTensor=inputTensorName, outputTensor=outputTensorName, diff --git a/python/sparkdl/transformers/keras_tensor.py b/python/sparkdl/transformers/keras_tensor.py index 1134ff73..63fd0036 100644 --- a/python/sparkdl/transformers/keras_tensor.py +++ b/python/sparkdl/transformers/keras_tensor.py @@ -60,6 +60,6 @@ def _transform(self, dataset): fetch_names=[outputTensorName]) # Create TFTransformer & use it to apply the loaded Keras model graph to our dataset transformer = TFTransformer(tfInputGraph=inputGraph, - inputMapping={self.getInputCol() : inputTensorName}, + inputMapping={self.getInputCol(): inputTensorName}, outputMapping={outputTensorName: self.getOutputCol()}) return transformer.transform(dataset) diff --git a/python/sparkdl/transformers/keras_utils.py b/python/sparkdl/transformers/keras_utils.py index 1e89f661..39a69e59 100644 --- a/python/sparkdl/transformers/keras_utils.py +++ b/python/sparkdl/transformers/keras_utils.py @@ -26,13 +26,13 @@ class KSessionWrap(): ... 
do some things that call Keras """ - def __init__(self, graph = None): + def __init__(self, graph=None): self.requested_graph = graph def __enter__(self): self.old_session = K.get_session() self.g = self.requested_graph or tf.Graph() - self.current_session = tf.Session(graph = self.g) + self.current_session = tf.Session(graph=self.g) K.set_session(self.current_session) return (self.current_session, self.g) diff --git a/python/sparkdl/transformers/named_image.py b/python/sparkdl/transformers/named_image.py index a7e87a64..b4534684 100644 --- a/python/sparkdl/transformers/named_image.py +++ b/python/sparkdl/transformers/named_image.py @@ -22,14 +22,14 @@ from pyspark.sql.types import (ArrayType, FloatType, StringType, StructField, StructType) import sparkdl.graph.utils as tfx -from sparkdl.image.imageIO import resizeImage +from sparkdl.image.imageIO import createResizeImageUDF import sparkdl.transformers.keras_applications as keras_apps from sparkdl.param import ( keyword_only, HasInputCol, HasOutputCol, SparkDLTypeConverters) from sparkdl.transformers.tf_image import TFImageTransformer -SUPPORTED_MODELS = ["InceptionV3", "Xception", "ResNet50","VGG16","VGG19"] +SUPPORTED_MODELS = ["InceptionV3", "Xception", "ResNet50", "VGG16", "VGG19"] class DeepImagePredictor(Transformer, HasInputCol, HasOutputCol): @@ -94,6 +94,7 @@ def _decodeOutputAsPredictions(self, df): # Also, we could put the computation directly in the main computation # graph or use a scala UDF for potentially better performance. 
topK = self.getOrDefault(self.topK) + def decode(predictions): pred_arr = np.expand_dims(np.array(predictions), axis=0) decoded = decode_predictions(pred_arr, top=topK)[0] @@ -211,13 +212,15 @@ def _transform(self, dataset): modelGraphSpec = _buildTFGraphForName(self.getModelName(), self.getFeaturize()) inputCol = self.getInputCol() resizedCol = "__sdl_imagesResized" - tfTransformer = TFImageTransformer(inputCol=resizedCol, - outputCol=self.getOutputCol(), - graph=modelGraphSpec["graph"], - inputTensor=modelGraphSpec["inputTensorName"], - outputTensor=modelGraphSpec["outputTensorName"], - outputMode=modelGraphSpec["outputMode"]) - resizeUdf = resizeImage(modelGraphSpec["inputTensorSize"]) + tfTransformer = TFImageTransformer( + channelOrder='BGR', + inputCol=resizedCol, + outputCol=self.getOutputCol(), + graph=modelGraphSpec["graph"], + inputTensor=modelGraphSpec["inputTensorName"], + outputTensor=modelGraphSpec["outputTensorName"], + outputMode=modelGraphSpec["outputMode"]) + resizeUdf = createResizeImageUDF(modelGraphSpec["inputTensorSize"]) result = tfTransformer.transform(dataset.withColumn(resizedCol, resizeUdf(inputCol))) return result.drop(resizedCol) diff --git a/python/sparkdl/transformers/tf_image.py b/python/sparkdl/transformers/tf_image.py index 152a7fea..3796288e 100644 --- a/python/sparkdl/transformers/tf_image.py +++ b/python/sparkdl/transformers/tf_image.py @@ -17,16 +17,20 @@ import tensorflow as tf import tensorframes as tfs +from pyspark import Row from pyspark.ml import Transformer from pyspark.ml.param import Param, Params from pyspark.sql.functions import udf -from sparkdl.image.imageIO import imageSchema, sparkModeLookup, SparkMode +import sparkdl.graph.utils as tfx +import sparkdl.image.imageIO as imageIO from sparkdl.param import ( keyword_only, HasInputCol, HasOutputCol, SparkDLTypeConverters, HasOutputMode) import sparkdl.transformers.utils as utils import sparkdl.utils.jvmapi as JVMAPI -import sparkdl.graph.utils as tfx + +from 
sparkdl.image.image import ImageSchema + __all__ = ['TFImageTransformer'] @@ -34,6 +38,7 @@ USER_GRAPH_NAMESPACE = 'given' NEW_OUTPUT_PREFIX = 'sdl_flattened' + class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode): """ Applies the Tensorflow graph to the image column in DataFrame. @@ -61,22 +66,27 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode): outputTensor = Param(Params._dummy(), "outputTensor", "A TensorFlow tensor object or name representing the output", typeConverter=SparkDLTypeConverters.toTFTensorName) + channelOrder = Param(Params._dummy(), "channelOrder", + "Strign specifying the expected color channel order, can be one of L,RGB,BGR", + typeConverter=SparkDLTypeConverters.toChannelOrder) @keyword_only - def __init__(self, inputCol=None, outputCol=None, graph=None, + def __init__(self, channelOrder, inputCol=None, outputCol=None, graph=None, inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None, outputMode="vector"): """ - __init__(self, inputCol=None, outputCol=None, graph=None, + __init__(self, channelOrder, inputCol=None, outputCol=None, graph=None, inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None, outputMode="vector") + :param: channelOrder: specify the ordering of the color channel, can be one of RGB, BGR, L (grayscale) """ super(TFImageTransformer, self).__init__() kwargs = self._input_kwargs self.setParams(**kwargs) + self.channelOrder = channelOrder @keyword_only - def setParams(self, inputCol=None, outputCol=None, graph=None, + def setParams(self, channelOrder=None, inputCol=None, outputCol=None, graph=None, inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None, outputMode="vector"): """ @@ -117,11 +127,11 @@ def _transform(self, dataset): with final_graph.as_default(): image = dataset[self.getInputCol()] image_df_exploded = (dataset - .withColumn("__sdl_image_height", image.height) - .withColumn("__sdl_image_width", image.width) - .withColumn("__sdl_image_nchannels", 
image.nChannels) - .withColumn("__sdl_image_data", image.data) - ) + .withColumn("__sdl_image_height", image.height) + .withColumn("__sdl_image_width", image.width) + .withColumn("__sdl_image_nchannels", image.nChannels) + .withColumn("__sdl_image_data", image.data) + ) final_output_name = self._getFinalOutputTensorName() output_tensor = final_graph.get_tensor_by_name(final_output_name) @@ -152,9 +162,11 @@ def _getImageDtype(self, dataset): # Assumes that the dtype for all images is the same in the given dataframe. pdf = dataset.select(self.getInputCol()).take(1) img = pdf[0][self.getInputCol()] - img_type = sparkModeLookup[img.mode] + img_type = imageIO.imageTypeByOrdinal(img.mode) return img_type.dtype + # TODO: duplicate code, same functionality as sparkdl.graph.pieces.py::builSpImageConverter + # TODO: It should be extracted as a util function and shared def _addReshapeLayers(self, tf_graph, dtype="uint8"): input_tensor_name = self.getInputTensor().name @@ -174,9 +186,10 @@ def _addReshapeLayers(self, tf_graph, dtype="uint8"): image_uint8 = tf.decode_raw(image_buffer, tf.uint8, name="decode_raw") image_float = tf.to_float(image_uint8) else: - assert dtype == SparkMode.FLOAT32, "Unsupported dtype for image: %s" % dtype + assert dtype == "float32", "Unsupported dtype for image: %s" % dtype image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw") image_reshaped = tf.reshape(image_float, shape, name="reshaped") + image_reshaped = imageIO.fixColorChannelOrdering(self.channelOrder, image_reshaped) image_reshaped_expanded = tf.expand_dims(image_reshaped, 0, name="expanded") # Add on the original graph @@ -213,17 +226,18 @@ def _convertOutputToImage(self, df, tfs_output_col, output_shape): assert len(output_shape) == 4, str(output_shape) + " does not have 4 dimensions" height = int(output_shape[1]) width = int(output_shape[2]) + def to_image(orig_image, numeric_data): # Assume the returned image has float pixels but same #channels as input - mode = 
orig_image.mode if orig_image.mode == "float32" else "RGB-float32" - return [mode, height, width, orig_image.nChannels, - bytearray(np.array(numeric_data).astype(np.float32).tobytes())] - to_image_udf = udf(to_image, imageSchema) - return ( - df.withColumn(self.getOutputCol(), - to_image_udf(df[self.getInputCol()], df[tfs_output_col])) - .drop(tfs_output_col) - ) + mode = imageIO.imageTypeByName('CV_32FC%d' % orig_image.nChannels) + data = bytearray(np.array(numeric_data).astype(np.float32).tobytes()) + nChannels = orig_image.nChannels + return Row(origin="", mode=mode.ord, height=height, + width=width, nChannels=nChannels, data=data) + to_image_udf = udf(to_image, ImageSchema.imageSchema['image'].dataType) + resDf = df.withColumn(self.getOutputCol(), to_image_udf( + df[self.getInputCol()], df[tfs_output_col])) + return resDf.drop(tfs_output_col) def _convertOutputToVector(self, df, tfs_output_col): """ diff --git a/python/sparkdl/transformers/tf_tensor.py b/python/sparkdl/transformers/tf_tensor.py index 3affabac..f144ecb5 100644 --- a/python/sparkdl/transformers/tf_tensor.py +++ b/python/sparkdl/transformers/tf_tensor.py @@ -30,6 +30,7 @@ logger = logging.getLogger('sparkdl') + class TFTransformer(Transformer, HasTFInputGraph, HasTFHParams, HasInputMapping, HasOutputMapping): """ Applies the TensorFlow graph to the array column in DataFrame. 
@@ -138,7 +139,8 @@ def _transform(self, dataset): tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names) # Feed dict maps from placeholder name to DF column name - feed_dict = {self._getSparkDlOpName(tnsr_name) : col_name for col_name, tnsr_name in input_mapping} + feed_dict = {self._getSparkDlOpName( + tnsr_name): col_name for col_name, tnsr_name in input_mapping} fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in out_tnsr_op_names] out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict) diff --git a/python/sparkdl/transformers/utils.py b/python/sparkdl/transformers/utils.py index b244365b..f3972559 100644 --- a/python/sparkdl/transformers/utils.py +++ b/python/sparkdl/transformers/utils.py @@ -19,15 +19,19 @@ IMAGE_INPUT_PLACEHOLDER_NAME = "sparkdl_image_input" + def imageInputPlaceholder(nChannels=None): return tf.placeholder(tf.float32, [None, None, None, nChannels], name=IMAGE_INPUT_PLACEHOLDER_NAME) + class ImageNetConstants: NUM_CLASSES = 1000 # InceptionV3 is used in a lot of tests, so we'll make this shortcut available # For other networks, see the keras_applications module. 
+ + class InceptionV3Constants: INPUT_SHAPE = (299, 299) NUM_OUTPUT_FEATURES = 131072 diff --git a/python/sparkdl/udf/keras_image_model.py b/python/sparkdl/udf/keras_image_model.py index f201417d..f5ba9db0 100644 --- a/python/sparkdl/udf/keras_image_model.py +++ b/python/sparkdl/udf/keras_image_model.py @@ -19,11 +19,15 @@ from sparkdl.graph.builder import GraphFunction, IsolatedSession from sparkdl.graph.pieces import buildSpImageConverter, buildFlattener from sparkdl.graph.tensorframes_udf import makeGraphUDF -from sparkdl.image.imageIO import imageSchema from sparkdl.utils import jvmapi as JVMAPI +import sparkdl.image.imageIO as imageIO + +from sparkdl.image.image import ImageSchema + logger = logging.getLogger('sparkdl') + def registerKerasImageUDF(udf_name, keras_model_or_file_path, preprocessor=None): """ Create a Keras image model as a Spark SQL UDF. @@ -95,10 +99,10 @@ def keras_load_img(fpath): JVMAPI.registerUDF( preproc_udf_name, _serialize_and_reload_with(preprocessor), - imageSchema) + ImageSchema.imageSchema['image'].dataType) keras_udf_name = '{}__model_predict'.format(udf_name) - stages = [('spimg', buildSpImageConverter("RGB")), + stages = [('spimg', buildSpImageConverter('RGB', "uint8")), ('model', GraphFunction.fromKeras(keras_model_or_file_path)), ('final', buildFlattener())] gfn = GraphFunction.fromList(stages) @@ -116,6 +120,7 @@ def keras_load_img(fpath): return gfn + def _serialize_and_reload_with(preprocessor): """ Retruns a function that performs the following steps @@ -131,13 +136,10 @@ def _serialize_and_reload_with(preprocessor): """ def udf_impl(spimg): import numpy as np - from PIL import Image from tempfile import NamedTemporaryFile - from sparkdl.image.imageIO import imageArrayToStruct, imageType + from sparkdl.image.imageIO import imageArrayToStruct - pil_mode = imageType(spimg).pilMode - img_shape = (spimg.width, spimg.height) - img = Image.frombytes(pil_mode, img_shape, bytes(spimg.data)) + img = imageIO.imageStructToPIL(spimg) # 
Warning: must use lossless format to guarantee consistency temp_fp = NamedTemporaryFile(suffix='.png') img.save(temp_fp, 'PNG') @@ -145,6 +147,9 @@ def udf_impl(spimg): assert isinstance(img_arr_reloaded, np.ndarray), \ "expect preprocessor to return a numpy array" img_arr_reloaded = img_arr_reloaded.astype(np.uint8) + # Keras works in RGB order, need to fix the order + img_arr_reloaded = imageIO.fixColorChannelOrdering( + currentOrder='RGB', imgAry=img_arr_reloaded) return imageArrayToStruct(img_arr_reloaded) return udf_impl diff --git a/python/sparkdl/utils/__init__.py b/python/sparkdl/utils/__init__.py index 459489e1..7084f22b 100644 --- a/python/sparkdl/utils/__init__.py +++ b/python/sparkdl/utils/__init__.py @@ -13,4 +13,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # - diff --git a/python/sparkdl/utils/jvmapi.py b/python/sparkdl/utils/jvmapi.py index 82c4e200..2fb4f58c 100644 --- a/python/sparkdl/utils/jvmapi.py +++ b/python/sparkdl/utils/jvmapi.py @@ -24,17 +24,21 @@ logger = logging.getLogger('sparkdl') + def _curr_sql_ctx(sqlCtx=None): _sql_ctx = sqlCtx if sqlCtx is not None else SQLContext._instantiatedContext logger.info("Spark SQL Context = " + str(_sql_ctx)) return _sql_ctx + def _curr_sc(): return SparkContext._active_spark_context + def _curr_jvm(): return _curr_sc()._jvm + def forClass(javaClassName, sqlCtx=None): """ Loads the JVM API object (lazily, because the spark context needs to be initialized @@ -48,6 +52,7 @@ def forClass(javaClassName, sqlCtx=None): jvm_class = jvm_thread.getContextClassLoader().loadClass(javaClassName) return jvm_class.newInstance().sqlContext(_curr_sql_ctx(sqlCtx)._ssql_ctx) + def pyUtils(): """ Exposing Spark PythonUtils @@ -55,20 +60,24 @@ def pyUtils(): """ return _curr_jvm().PythonUtils + def default(): """ Default JVM Python Interface class """ return forClass(javaClassName=PYTHON_INTERFACE_CLASSNAME) + def createTensorFramesModelBuilder(): """ Create 
TensorFrames model builder using the Scala API """ return forClass(javaClassName=MODEL_FACTORY_CLASSNAME) + def listToMLlibVectorUDF(col): """ Map struct column from list to MLlib vector """ return Column(default().listToMLlibVectorUDF(col._jc)) # pylint: disable=W0212 + def registerPipeline(name, ordered_udf_names): - """ + """ Given a sequence of @ordered_udf_names f1, f2, ..., fn Create a pipelined UDF as fn(...f2(f1())) """ @@ -76,6 +85,7 @@ def registerPipeline(name, ordered_udf_names): "must provide more than one ordered udf names" return default().registerPipeline(name, ordered_udf_names) + def registerUDF(name, function_body, schema): """ Register a single UDF """ return _curr_sql_ctx().registerFunction(name, function_body, schema) diff --git a/python/sparkdl/utils/keras_model.py b/python/sparkdl/utils/keras_model.py index 0425a068..f8f4371b 100644 --- a/python/sparkdl/utils/keras_model.py +++ b/python/sparkdl/utils/keras_model.py @@ -25,6 +25,7 @@ __all__ = ['model_to_bytes', 'bytes_to_model', 'bytes_to_h5file', 'is_valid_loss_function', 'is_valid_optimizer'] + def model_to_bytes(model): """ Serialize the Keras model to HDF5 and load the file as bytes. @@ -41,6 +42,7 @@ def model_to_bytes(model): shutil.rmtree(temp_dir, ignore_errors=True) return file_bytes + def bytes_to_h5file(modelBytes): """ Dump HDF5 file content bytes to a local file @@ -52,6 +54,7 @@ def bytes_to_h5file(modelBytes): fout.write(modelBytes) return temp_path + def bytes_to_model(modelBytes, remove_temp_path=True): """ Convert a Keras model from a byte string to a Keras model instance. @@ -66,6 +69,7 @@ def bytes_to_model(modelBytes, remove_temp_path=True): shutil.rmtree(temp_dir, ignore_errors=True) return model + def _get_loss_function(identifier): """ Retrieves a Keras loss function instance. 
@@ -74,6 +78,7 @@ def _get_loss_function(identifier): """ return keras.losses.get(identifier) + def is_valid_loss_function(identifier): """ Check if a named loss function is supported in Keras """ try: @@ -82,6 +87,7 @@ def is_valid_loss_function(identifier): except ValueError: return False + def _get_optimizer(identifier): """ Retrieves a Keras Optimizer instance. @@ -90,6 +96,7 @@ def _get_optimizer(identifier): """ return keras.optimizers.get(identifier) + def is_valid_optimizer(identifier): """ Check if a named optimizer is supported in Keras """ try: diff --git a/python/tests/__init__.py b/python/tests/__init__.py index a464d43a..97c96d6c 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import sparkdl diff --git a/python/tests/estimators/test_keras_estimators.py b/python/tests/estimators/test_keras_estimators.py index f6fcbff5..1ae11066 100644 --- a/python/tests/estimators/test_keras_estimators.py +++ b/python/tests/estimators/test_keras_estimators.py @@ -37,6 +37,7 @@ from ..tests import SparkDLTestCase from ..transformers.image_utils import getSampleImagePaths + def _load_image_from_uri(local_uri): img = (PIL.Image .open(local_uri) @@ -46,6 +47,7 @@ def _load_image_from_uri(local_uri): img_tnsr = preprocess_input(img_arr[np.newaxis, :]) return img_tnsr + class KerasEstimatorsTest(SparkDLTestCase): def _create_train_image_uris_and_labels(self, repeat_factor=1, cardinality=100): diff --git a/python/tests/graph/test_builder.py b/python/tests/graph/test_builder.py index 93b3c9f5..301da133 100644 --- a/python/tests/graph/test_builder.py +++ b/python/tests/graph/test_builder.py @@ -104,7 +104,6 @@ def test_import_export_graph_function(self): self.assertEqual(gfn_tgt.output_names, gfn_ref.output_names) self.assertEqual(str(gfn_tgt.graph_def), str(gfn_ref.graph_def)) - def test_keras_consistency(self): """ Exported model in 
Keras should get same result as original """ diff --git a/python/tests/graph/test_import.py b/python/tests/graph/test_import.py index 36501568..e92d87ac 100644 --- a/python/tests/graph/test_import.py +++ b/python/tests/graph/test_import.py @@ -29,7 +29,7 @@ class TestGraphImport(object): def test_graph_novar(self): gin = _build_graph_input(lambda session: - TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], + TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], [_tensor_output_name])) _check_input_novar(gin) @@ -80,7 +80,6 @@ def test_saved_model_iomap(self): and _translated_output_mapping == _expected_output_mapping, \ err_msg.format(_translated_input_mapping, _translated_output_mapping) - def test_saved_graph_novar(self): with _make_temp_directory() as tmp_dir: saved_model_dir = os.path.join(tmp_dir, 'saved_model') @@ -123,8 +122,8 @@ def gin_fun(session): def test_graphdef_novar_2(self): gin = _build_graph_input_2(lambda session: - TFInputGraph.fromGraphDef(session.graph.as_graph_def(), - [_tensor_input_name], [_tensor_output_name])) + TFInputGraph.fromGraphDef(session.graph.as_graph_def(), + [_tensor_input_name], [_tensor_output_name])) _check_output_2(gin, np.array([1, 2, 3]), np.array([2, 2, 2]), 1) def test_saved_graph_novar_2(self): @@ -138,6 +137,7 @@ def gin_fun(session): gin = _build_graph_input_2(gin_fun) _check_output_2(gin, np.array([1, 2, 3]), np.array([2, 2, 2]), 1) + _serving_tag = "serving_tag" _serving_sigdef_key = 'prediction_signature' # The name of the input tensor diff --git a/python/tests/graph/test_pieces.py b/python/tests/graph/test_pieces.py index 9d659265..1395bf41 100644 --- a/python/tests/graph/test_pieces.py +++ b/python/tests/graph/test_pieces.py @@ -23,6 +23,7 @@ import numpy as np import numpy.random as prng import tensorflow as tf + import keras.backend as K from keras.applications import InceptionV3 from keras.applications import inception_v3 as iv3 @@ -36,10 +37,12 @@ from pyspark.sql import 
DataFrame, Row from pyspark.sql.functions import udf -from sparkdl.image.imageIO import imageArrayToStruct, SparkMode from sparkdl.graph.builder import IsolatedSession, GraphFunction import sparkdl.graph.pieces as gfac import sparkdl.graph.utils as tfx +from sparkdl.image.imageIO import imageArrayToStruct +from sparkdl.image.imageIO import imageTypeByOrdinal + from ..tests import SparkDLTestCase from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF @@ -52,17 +55,19 @@ def test_spimage_converter_module(self): img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) def exec_gfn_spimg_decode(spimg_dict, img_dtype): - gfn = gfac.buildSpImageConverter(img_dtype) + gfn = gfac.buildSpImageConverter('BGR', img_dtype) with IsolatedSession() as issn: feeds, fetches = issn.importGraphFunction(gfn, prefix="") - feed_dict = dict((tnsr, spimg_dict[tfx.op_name(tnsr, issn.graph)]) for tnsr in feeds) + feed_dict = dict( + (tnsr, spimg_dict[tfx.op_name(tnsr, issn.graph)]) for tnsr in feeds) img_out = issn.run(fetches[0], feed_dict=feed_dict) return img_out def check_image_round_trip(img_arr): spimg_dict = imageArrayToStruct(img_arr).asDict() spimg_dict['data'] = bytes(spimg_dict['data']) - img_arr_out = exec_gfn_spimg_decode(spimg_dict, spimg_dict['mode']) + img_arr_out = exec_gfn_spimg_decode( + spimg_dict, imageTypeByOrdinal(spimg_dict['mode']).dtype) self.assertTrue(np.all(img_arr_out == img_arr)) for fp in img_fpaths: @@ -71,7 +76,7 @@ def check_image_round_trip(img_arr): img_arr_byte = img_to_array(img).astype(np.uint8) check_image_round_trip(img_arr_byte) - img_arr_float = img_to_array(img).astype(np.float) + img_arr_float = img_to_array(img).astype(np.float32) check_image_round_trip(img_arr_float) img_arr_preproc = iv3.preprocess_input(img_to_array(img)) @@ -143,7 +148,7 @@ def test_pipeline(self): img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) xcpt_model = Xception(weights="imagenet") - stages = [('spimage', 
gfac.buildSpImageConverter(SparkMode.RGB_FLOAT32)), + stages = [('spimage', gfac.buildSpImageConverter('BGR', 'float32')), ('xception', GraphFunction.fromKeras(xcpt_model))] piped_model = GraphFunction.fromList(stages) @@ -159,7 +164,8 @@ def test_pipeline(self): with IsolatedSession() as issn: # Need blank import scope name so that spimg fields match the input names feeds, fetches = issn.importGraphFunction(piped_model, prefix="") - feed_dict = dict((tnsr, spimg_input_dict[tfx.op_name(tnsr, issn.graph)]) for tnsr in feeds) + feed_dict = dict( + (tnsr, spimg_input_dict[tfx.op_name(tnsr, issn.graph)]) for tnsr in feeds) preds_tgt = issn.run(fetches[0], feed_dict=feed_dict) # Uncomment the line below to see the graph # tfx.write_visualization_html(issn.graph, diff --git a/python/tests/image/test_imageIO.py b/python/tests/image/test_imageIO.py index 0b6f6e61..2c172603 100644 --- a/python/tests/image/test_imageIO.py +++ b/python/tests/image/test_imageIO.py @@ -24,9 +24,12 @@ from pyspark.sql.types import BinaryType, StringType, StructField, StructType from sparkdl.image import imageIO +from sparkdl.image.image import ImageSchema from ..tests import SparkDLTestCase -# Create dome fake image data to work with +# Create some fake image data to work with + + def create_image_data(): # Random image-like data array = np.random.randint(0, 256, (10, 11, 3), 'uint8') @@ -38,7 +41,9 @@ def create_image_data(): # Get Png data as stream pngData = imgFile.read() - return array, pngData + # PIL is RGB but image schema is BGR => flip the channels + return imageIO._reverseChannels(array), pngData + array, pngData = create_image_data() @@ -72,71 +77,49 @@ def tearDownClass(cls): super(TestReadImages, cls).tearDownClass() cls.binaryFilesMock = None - def test_decodeImage(self): - badImg = imageIO._decodeImage(b"xxx") - self.assertIsNone(badImg) - imgRow = imageIO._decodeImage(pngData) - self.assertIsNotNone(imgRow) - self.assertEqual(len(imgRow), len(imageIO.imageSchema.names)) - for n 
in imageIO.imageSchema.names: - imgRow[n] - def test_resize(self): + self.assertRaises(ValueError, imageIO.createResizeImageUDF, [1, 2, 3]) + + make_smaller = imageIO.createResizeImageUDF([4, 5]).func imgAsRow = imageIO.imageArrayToStruct(array) - smaller = imageIO._resizeFunction([4, 5]) - smallerImg = smaller(imgAsRow) - for n in imageIO.imageSchema.names: - smallerImg[n] + smallerImg = make_smaller(imgAsRow) self.assertEqual(smallerImg.height, 4) self.assertEqual(smallerImg.width, 5) - sameImage = imageIO._resizeFunction([imgAsRow.height, imgAsRow.width])(imgAsRow) - self.assertEqual(sameImage, sameImage) - - self.assertRaises(ValueError, imageIO._resizeFunction, [1, 2, 3]) - - def test_imageArrayToStruct(self): - SparkMode = imageIO.SparkMode - # Check converting with matching types - height, width, chan = array.shape - imgAsStruct = imageIO.imageArrayToStruct(array) - self.assertEqual(imgAsStruct.height, height) - self.assertEqual(imgAsStruct.width, width) - self.assertEqual(imgAsStruct.data, array.tobytes()) - - # Check casting - imgAsStruct = imageIO.imageArrayToStruct(array, SparkMode.RGB_FLOAT32) - self.assertEqual(imgAsStruct.height, height) - self.assertEqual(imgAsStruct.width, width) - self.assertEqual(len(imgAsStruct.data), array.size * 4) - - # Check channel mismatch - self.assertRaises(ValueError, imageIO.imageArrayToStruct, array, SparkMode.FLOAT32) - - # Check that unsafe cast raises error - floatArray = np.zeros((3, 4, 3), dtype='float32') - self.assertRaises(ValueError, imageIO.imageArrayToStruct, floatArray, SparkMode.RGB) - - def test_image_round_trip(self): - # Test round trip: array -> png -> sparkImg -> array - binarySchema = StructType([StructField("data", BinaryType(), False)]) - df = self.session.createDataFrame([[bytearray(pngData)]], binarySchema) - - # Convert to images - decImg = udf(imageIO._decodeImage, imageIO.imageSchema) - imageDF = df.select(decImg("data").alias("image")) - row = imageDF.first() - - testArray = 
imageIO.imageStructToArray(row.image) - self.assertEqual(testArray.shape, array.shape) - self.assertEqual(testArray.dtype, array.dtype) - self.assertTrue(np.all(array == testArray)) + # Compare to PIL resizing + imgAsPIL = PIL.Image.fromarray(obj=imageIO._reverseChannels(array)).resize((5, 4)) + smallerAry = imageIO._reverseChannels(np.asarray(imgAsPIL)) + np.testing.assert_array_equal(smallerAry, imageIO.imageStructToArray(smallerImg)) + # Test that resize with the same size is a no-op + sameImage = imageIO.createResizeImageUDF((imgAsRow.height, imgAsRow.width)).func(imgAsRow) + self.assertEqual(imgAsRow, sameImage) + # Test that we have a valid image schema (all fields are in) + for n in ImageSchema.imageSchema['image'].dataType.names: + smallerImg[n] + def test_imageConversions(self): + """" + Test conversion image array <-> image struct + """ + def _test(array): + height, width, chan = array.shape + imgAsStruct = imageIO.imageArrayToStruct(array) + self.assertEqual(imgAsStruct.height, height) + self.assertEqual(imgAsStruct.width, width) + self.assertEqual(imgAsStruct.data, array.tobytes()) + imgReconstructed = imageIO.imageStructToArray(imgAsStruct) + np.testing.assert_array_equal(array, imgReconstructed) + for nChannels in (1, 3, 4): + # unsigned bytes + _test(np.random.randint(0, 256, (10, 11, nChannels), 'uint8')) + _test(np.random.random_sample((10, 11, nChannels)).astype('float32')) + + # read images now part of spark, no need to test it here def test_readImages(self): # Test that reading - imageDF = imageIO._readImages("some/path", 2, self.binaryFilesMock) + imageDF = imageIO._readImagesWithCustomFn( + "file/path", decode_f=imageIO.PIL_decode, numPartition=2, sc=self.binaryFilesMock) self.assertTrue("image" in imageDF.schema.names) - self.assertTrue("filePath" in imageDF.schema.names) # The DF should have 2 images and 1 null. 
self.assertEqual(imageDF.count(), 3) @@ -146,19 +129,20 @@ def test_readImages(self): img = validImages.first().image self.assertEqual(img.height, array.shape[0]) self.assertEqual(img.width, array.shape[1]) - self.assertEqual(imageIO.imageType(img).nChannels, array.shape[2]) + self.assertEqual(imageIO.imageTypeByOrdinal(img.mode).nChannels, array.shape[2]) + # array comes out of PIL and is in RGB order self.assertEqual(img.data, array.tobytes()) def test_udf_schema(self): # Test that utility functions can be used to create a udf that accepts and return # imageSchema def do_nothing(imgRow): - imType = imageIO.imageType(imgRow) array = imageIO.imageStructToArray(imgRow) - return imageIO.imageArrayToStruct(array, imType.sparkMode) - do_nothing_udf = udf(do_nothing, imageIO.imageSchema) + return imageIO.imageArrayToStruct(array) + do_nothing_udf = udf(do_nothing, ImageSchema.imageSchema['image'].dataType) - df = imageIO._readImages("path", 2, self.binaryFilesMock) + df = imageIO._readImagesWithCustomFn( + "file/path", decode_f=imageIO.PIL_decode, numPartition=2, sc=self.binaryFilesMock) df = df.filter(col('image').isNotNull()).withColumn("test", do_nothing_udf('image')) self.assertEqual(df.first().test.data, array.tobytes()) df.printSchema() diff --git a/python/tests/param/params_test.py b/python/tests/param/params_test.py index f7479385..532aaed1 100644 --- a/python/tests/param/params_test.py +++ b/python/tests/param/params_test.py @@ -40,6 +40,7 @@ description='tensor name required'), ] + class ParamsConverterTest(PythonUnitTestCase): """ Test MLlib Params introduced in Spark Deep Learning Pipeline diff --git a/python/tests/tests.py b/python/tests/tests.py index d5520d8b..f9150521 100644 --- a/python/tests/tests.py +++ b/python/tests/tests.py @@ -32,6 +32,7 @@ from pyspark.sql import SQLContext from pyspark.sql import SparkSession + class PythonUnitTestCase(unittest.TestCase): # We try to use unittest2 for python 2.6 or earlier # This class is created to avoid 
replicating this logic in various places. @@ -70,12 +71,11 @@ class SparkDLTestCase(TestSparkContext, unittest.TestCase): def setUpClass(cls): cls.setup_env() - @classmethod def tearDownClass(cls): cls.tear_down_env() - def assertDfHasCols(self, df, cols = []): + def assertDfHasCols(self, df, cols=[]): map(lambda c: self.assertIn(c, df.columns), cols) diff --git a/python/tests/transformers/image_utils.py b/python/tests/transformers/image_utils.py index 364fd124..d012202f 100644 --- a/python/tests/transformers/image_utils.py +++ b/python/tests/transformers/image_utils.py @@ -14,11 +14,12 @@ # import os -from glob import glob import tempfile import unittest +from glob import glob from warnings import warn + from keras.applications import InceptionV3 from keras.applications.inception_v3 import preprocess_input, decode_predictions from keras.preprocessing.image import img_to_array, load_img @@ -26,6 +27,7 @@ import numpy as np import PIL.Image + from pyspark.sql.types import StringType from sparkdl.image import imageIO @@ -38,8 +40,12 @@ def _getSampleJPEGDir(): cur_dir = os.path.dirname(__file__) return os.path.join(cur_dir, "../resources/images") +def getImageFiles(): + return glob(os.path.join(_getSampleJPEGDir(), "*")) + def getSampleImageDF(): - return imageIO.readImages(_getSampleJPEGDir()) + return imageIO.readImagesWithCustomFn(path=_getSampleJPEGDir(), decode_f=imageIO.PIL_decode) + def getSampleImagePaths(): dirpath = _getSampleJPEGDir() @@ -47,6 +53,7 @@ def getSampleImagePaths(): if f.endswith('.jpg')] return files + def getSampleImagePathsDF(sqlContext, colName): files = getSampleImagePaths() return sqlContext.createDataFrame(files, StringType()).toDF(colName) @@ -54,16 +61,16 @@ def getSampleImagePathsDF(sqlContext, colName): # Methods for making comparisons between outputs of using different frameworks. # For ImageNet. 
+ class ImageNetOutputComparisonTestCase(unittest.TestCase): - def transformOutputToComparables(self, collected, uri_col, output_col): + def transformOutputToComparables(self, collected, output_col, get_uri): values = {} topK = {} for row in collected: - uri = row[uri_col] + uri = get_uri(row) predictions = row[output_col] self.assertEqual(len(predictions), ImageNetConstants.NUM_CLASSES) - values[uri] = np.expand_dims(predictions, axis=0) topK[uri] = decode_predictions(values[uri], top=5)[0] return values, topK @@ -92,20 +99,6 @@ def compareClassSets(self, preds1, preds2): self.assertEqual(set([v[1] for v in v1]), set([v[1] for v in preds2[k]])) -def getSampleImageList(): - imageFiles = glob(os.path.join(_getSampleJPEGDir(), "*")) - images = [] - for f in imageFiles: - try: - img = PIL.Image.open(f) - except IOError: - warn("Could not read file in image directory.") - images.append(None) - else: - images.append(img) - return imageFiles, images - - def executeKerasInceptionV3(image_df, uri_col="filePath"): """ Apply Keras InceptionV3 Model on input DataFrame. 
@@ -127,6 +120,7 @@ def executeKerasInceptionV3(image_df, uri_col="filePath"): topK[raw_uri] = decode_predictions(values[raw_uri], top=5)[0] return values, topK + def loadAndPreprocessKerasInceptionV3(raw_uri): # this is the canonical way to load and prep images in keras uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri @@ -134,6 +128,7 @@ def loadAndPreprocessKerasInceptionV3(raw_uri): image = np.expand_dims(image, axis=0) return preprocess_input(image) + def prepInceptionV3KerasModelFile(fileName): model_dir_tmp = tempfile.mkdtemp("sparkdl_keras_tests", dir="/tmp") path = model_dir_tmp + "/" + fileName diff --git a/python/tests/transformers/keras_image_test.py b/python/tests/transformers/keras_image_test.py index 397c780f..d215d991 100644 --- a/python/tests/transformers/keras_image_test.py +++ b/python/tests/transformers/keras_image_test.py @@ -60,7 +60,8 @@ def test_inceptionV3_vs_keras(self): self.assertEqual(len(final_df.columns), 2) collected = final_df.collect() - tvals, ttopK = self.transformOutputToComparables(collected, input_col, output_col) + tvals, ttopK = self.transformOutputToComparables( + collected, output_col, lambda row: row["uri"]) kvals, ktopK = image_utils.executeKerasInceptionV3(uri_df, uri_col=input_col) self.compareClassSets(ktopK, ttopK) diff --git a/python/tests/transformers/keras_transformer_test.py b/python/tests/transformers/keras_transformer_test.py index 3d867f26..c8869e92 100644 --- a/python/tests/transformers/keras_transformer_test.py +++ b/python/tests/transformers/keras_transformer_test.py @@ -63,13 +63,13 @@ def _test_keras_transformer_helper(self, model, model_filename): id_col = "id" # Create Keras model, persist it to disk, and create KerasTransformer - save_filename = "%s.h5"%(model_filename) + save_filename = "%s.h5" % (model_filename) model_path = self._writeKerasModelFile(model, save_filename) transformer = KerasTransformer(inputCol=input_col, outputCol=output_col, modelFile=model_path) # Load dataset, 
transform it with KerasTransformer - input_shape = list(model.input_shape[1:]) # Get shape of a single example + input_shape = list(model.input_shape[1:]) # Get shape of a single example df = self._getInputDF(self.sql, inputShape=input_shape, inputCol=input_col, idCol=id_col) final_df = transformer.transform(df) sparkdl_predictions = self._convertOutputToComparables(final_df, id_col, output_col) @@ -85,7 +85,7 @@ def _test_keras_transformer_helper(self, model, model_filename): diff_tolerance = 1e-5 assert np.allclose(sparkdl_predictions, keras_predictions, atol=diff_tolerance), "" \ "KerasTransformer output differed (absolute difference) from Keras model output by " \ - "as much as %s, maximum allowed deviation = %s"%(max_pred_diff, diff_tolerance) + "as much as %s, maximum allowed deviation = %s" % (max_pred_diff, diff_tolerance) def _getKerasModelWeightInitializer(self): """ @@ -105,7 +105,7 @@ def _createNumpyData(self, num_examples, example_shape): def _getInputDF(self, sqlContext, inputShape, inputCol, idCol): """ Return a DataFrame containing a long ID column and an input column of arrays. 
""" x_train = self._createNumpyData(num_examples=20, example_shape=inputShape) - train_rows = [{idCol : i, inputCol : x_train[i].tolist()} for i in range(len(x_train))] + train_rows = [{idCol: i, inputCol: x_train[i].tolist()} for i in range(len(x_train))] return sqlContext.createDataFrame(train_rows) def _writeKerasModelFile(self, model, filename): diff --git a/python/tests/transformers/named_image_InceptionV3_test.py b/python/tests/transformers/named_image_InceptionV3_test.py index 91f53bff..22594062 100644 --- a/python/tests/transformers/named_image_InceptionV3_test.py +++ b/python/tests/transformers/named_image_InceptionV3_test.py @@ -15,6 +15,7 @@ from .named_image_test import NamedImageTransformerBaseTestCase + class NamedImageTransformerInceptionV3Test(NamedImageTransformerBaseTestCase): __test__ = True diff --git a/python/tests/transformers/named_image_ResNet50_test.py b/python/tests/transformers/named_image_ResNet50_test.py index 390bf2a2..37653c02 100644 --- a/python/tests/transformers/named_image_ResNet50_test.py +++ b/python/tests/transformers/named_image_ResNet50_test.py @@ -15,6 +15,7 @@ from .named_image_test import NamedImageTransformerBaseTestCase + class NamedImageTransformerResNet50Test(NamedImageTransformerBaseTestCase): __test__ = True diff --git a/python/tests/transformers/named_image_VGG16_test.py b/python/tests/transformers/named_image_VGG16_test.py index f050e994..81d60018 100644 --- a/python/tests/transformers/named_image_VGG16_test.py +++ b/python/tests/transformers/named_image_VGG16_test.py @@ -20,4 +20,4 @@ class NamedImageTransformerVGG16Test(NamedImageTransformerBaseTestCase): __test__ = os.getenv('RUN_ONLY_LIGHT_TESTS', False) != "True" name = "VGG16" - numPartitionsOverride = 1 # hits OOM if more than 2 threads + numPartitionsOverride = 1 # hits OOM if more than 2 threads diff --git a/python/tests/transformers/named_image_VGG19_test.py b/python/tests/transformers/named_image_VGG19_test.py index 38f4736f..f5d0b381 100644 --- 
a/python/tests/transformers/named_image_VGG19_test.py +++ b/python/tests/transformers/named_image_VGG19_test.py @@ -16,7 +16,8 @@ import os from .named_image_test import NamedImageTransformerBaseTestCase + class NamedImageTransformerVGG19Test(NamedImageTransformerBaseTestCase): __test__ = os.getenv('RUN_ONLY_LIGHT_TESTS', False) != "True" name = "VGG19" - numPartitionsOverride = 1 # hits OOM if more than 2 threads \ No newline at end of file + numPartitionsOverride = 1 # hits OOM if more than 2 threads diff --git a/python/tests/transformers/named_image_Xception_test.py b/python/tests/transformers/named_image_Xception_test.py index c76d2147..4b56e0a0 100644 --- a/python/tests/transformers/named_image_Xception_test.py +++ b/python/tests/transformers/named_image_Xception_test.py @@ -15,6 +15,7 @@ from .named_image_test import NamedImageTransformerBaseTestCase + class NamedImageTransformerXceptionTest(NamedImageTransformerBaseTestCase): __test__ = True diff --git a/python/tests/transformers/named_image_test.py b/python/tests/transformers/named_image_test.py index 2cd265fd..110be520 100644 --- a/python/tests/transformers/named_image_test.py +++ b/python/tests/transformers/named_image_test.py @@ -14,6 +14,11 @@ # import numpy as np +import os + +from glob import glob +from PIL import Image + from keras.applications import resnet50 import tensorflow as tf @@ -26,8 +31,13 @@ import sparkdl.transformers.keras_applications as keras_apps from sparkdl.transformers.named_image import (DeepImagePredictor, DeepImageFeaturizer, _buildTFGraphForName) + +from sparkdl.image.image import ImageSchema + from ..tests import SparkDLTestCase -from .image_utils import getSampleImageDF, getSampleImageList +from .image_utils import getSampleImageDF +from.image_utils import getImageFiles + class KerasApplicationModelTestCase(SparkDLTestCase): @@ -63,30 +73,41 @@ class NamedImageTransformerBaseTestCase(SparkDLTestCase): # Allow subclasses to force number of partitions - a hack to avoid OOM 
issues numPartitionsOverride = None + @classmethod + def getSampleImageList(cls): + shape = cls.appModel.inputShape() + imageFiles = getImageFiles() + images = [imageIO.PIL_to_imageStruct(Image.open(f).resize(shape)) for f in imageFiles] + return imageFiles, np.array(images) + @classmethod def setUpClass(cls): super(NamedImageTransformerBaseTestCase, cls).setUpClass() - cls.appModel = keras_apps.getKerasApplicationModel(cls.name) - shape = cls.appModel.inputShape() - - imgFiles, images = getSampleImageList() - imageArray = np.empty((len(images), shape[0], shape[1], 3), 'uint8') - for i, img in enumerate(images): - assert img is not None and img.mode == "RGB" - imageArray[i] = np.array(img.resize(shape)) + imgFiles, imageArray = cls.getSampleImageList() cls.imageArray = imageArray - + cls.imgFiles = imgFiles + cls.fileOrder = {imgFiles[i].split('/')[-1]: i for i in range(len(imgFiles))} # Predict the class probabilities for the images in our test library using keras API # and cache for use by multiple tests. preppedImage = cls.appModel._testPreprocess(imageArray.astype('float32')) - cls.kerasPredict = cls.appModel._testKerasModel(include_top=True).predict(preppedImage) + cls.preppedImage = preppedImage + cls.kerasPredict = cls.appModel._testKerasModel( + include_top=True).predict(preppedImage, batch_size=1) cls.kerasFeatures = cls.appModel._testKerasModel(include_top=False).predict(preppedImage) cls.imageDF = getSampleImageDF().limit(5) if(cls.numPartitionsOverride): cls.imageDf = cls.imageDF.coalesce(cls.numPartitionsOverride) + def _sortByFileOrder(self, ary): + """ + This is to ensure we are comparing compatible sequences of predictions. + Sorts the results according to the order in which the files have been read by python. + Note: Java and python can read files in different order. 
+ """ + fileOrder = self.fileOrder + return sorted(ary, key=lambda x: fileOrder[x['image']['origin'].split('/')[-1]]) def test_buildtfgraphforname(self): """" @@ -112,15 +133,16 @@ def test_DeepImagePredictorNoReshape(self): """ imageArray = self.imageArray kerasPredict = self.kerasPredict + def rowWithImage(img): # return [imageIO.imageArrayToStruct(img.astype('uint8'), imageType.sparkMode)] - row = imageIO.imageArrayToStruct(img.astype('uint8'), imageIO.SparkMode.RGB) + row = imageIO.imageArrayToStruct(img.astype('uint8')) # re-order row to avoid pyspark bug - return [[getattr(row, field.name) for field in imageIO.imageSchema]] + return [[getattr(row, field.name) for field in ImageSchema.imageSchema['image'].dataType]] # test: predictor vs keras on resized images rdd = self.sc.parallelize([rowWithImage(img) for img in imageArray]) - dfType = StructType([StructField("image", imageIO.imageSchema)]) + dfType = ImageSchema.imageSchema imageDf = rdd.toDF(dfType) if self.numPartitionsOverride: imageDf = imageDf.coalesce(self.numPartitionsOverride) @@ -140,9 +162,8 @@ def test_DeepImagePredictor(self): kerasPredict = self.kerasPredict transformer = DeepImagePredictor(inputCol='image', modelName=self.name, outputCol="prediction",) - fullPredict = transformer.transform(self.imageDF).collect() + fullPredict = self._sortByFileOrder(transformer.transform(self.imageDF).collect()) fullPredict = np.array([i.prediction for i in fullPredict]) - self.assertEqual(kerasPredict.shape, fullPredict.shape) np.testing.assert_array_almost_equal(kerasPredict, fullPredict, decimal=6) @@ -171,7 +192,7 @@ def test_featurization(self): transformer = DeepImageFeaturizer(inputCol="image", outputCol=output_col, modelName=self.name) transformed_df = transformer.transform(self.imageDF) - collected = transformed_df.collect() + collected = self._sortByFileOrder(transformed_df.collect()) features = np.array([i.prediction for i in collected]) # Note: keras features may be multi-dimensional np arrays, 
but transformer features @@ -193,7 +214,7 @@ def test_featurizer_in_pipeline(self): # add arbitrary labels to run logistic regression # TODO: it's weird that the test fails on some combinations of labels. check why. label_udf = udf(lambda x: abs(hash(x)) % 2, IntegerType()) - train_df = self.imageDF.withColumn("label", label_udf(self.imageDF["filePath"])) + train_df = self.imageDF.withColumn("label", label_udf(self.imageDF["image"]["origin"])) lrModel = pipeline.fit(train_df) # see if we at least get the training examples right. diff --git a/python/tests/transformers/tf_image_test.py b/python/tests/transformers/tf_image_test.py index ed495a73..f1a346f2 100644 --- a/python/tests/transformers/tf_image_test.py +++ b/python/tests/transformers/tf_image_test.py @@ -22,6 +22,7 @@ import sparkdl.graph.utils as tfx from sparkdl.image.imageIO import imageStructToArray +from sparkdl.image import imageIO from sparkdl.transformers.keras_utils import KSessionWrap from sparkdl.transformers.tf_image import TFImageTransformer import sparkdl.transformers.utils as utils @@ -36,7 +37,6 @@ class TFImageTransformerExamplesTest(SparkDLTestCase, ImageNetOutputComparisonTe # Test loading & pre-processing as an example of a simple graph # NOTE: resizing here/tensorflow and in keras workflow are different, so the # test would fail with resizing added in. 
- def _loadImageViaKeras(self, raw_uri): uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri image = img_to_array(load_img(uri)) @@ -47,10 +47,11 @@ def test_load_image_vs_keras(self): g = tf.Graph() with g.as_default(): image_arr = utils.imageInputPlaceholder() - preprocessed = preprocess_input(image_arr) + # keras expects array in RGB order, we get it from image schema in BGR => need to flip + preprocessed = preprocess_input(imageIO._reverseChannels(image_arr)) output_col = "transformed_image" - transformer = TFImageTransformer(inputCol="image", outputCol=output_col, graph=g, + transformer = TFImageTransformer(channelOrder='BGR', inputCol="image", outputCol=output_col, graph=g, inputTensor=image_arr, outputTensor=preprocessed.name, outputMode="vector") @@ -60,12 +61,35 @@ def test_load_image_vs_keras(self): for row in df.collect(): processed = np.array(row[output_col]).astype(np.float32) # compare to keras loading - images = self._loadImageViaKeras(row["filePath"]) + images = self._loadImageViaKeras(row["image"]['origin']) image = images[0] image.shape = (1, image.shape[0] * image.shape[1] * image.shape[2]) keras_processed = image[0] - self.assertTrue( (processed == keras_processed).all() ) + self.assertTrue((processed == keras_processed).all()) + def test_load_image_vs_keras_RGB(self): + g = tf.Graph() + with g.as_default(): + image_arr = utils.imageInputPlaceholder() + # keras expects array in RGB order, we get it from image schema in BGR => need to flip + preprocessed = preprocess_input(image_arr) + + output_col = "transformed_image" + transformer = TFImageTransformer(channelOrder='RGB', inputCol="image", outputCol=output_col, graph=g, + inputTensor=image_arr, outputTensor=preprocessed.name, + outputMode="vector") + + image_df = image_utils.getSampleImageDF() + df = transformer.transform(image_df.limit(5)) + + for row in df.collect(): + processed = np.array(row[output_col]).astype(np.float32) + # compare to keras loading + images = 
self._loadImageViaKeras(row["image"]['origin']) + image = images[0] + image.shape = (1, image.shape[0] * image.shape[1] * image.shape[2]) + keras_processed = image[0] + self.assertTrue((processed == keras_processed).all()) # Test full pre-processing for InceptionV3 as an example of a simple computation graph @@ -74,11 +98,12 @@ def _preprocessingInceptionV3Transformed(self, outputMode, outputCol): with g.as_default(): image_arr = utils.imageInputPlaceholder() resized_images = tf.image.resize_images(image_arr, InceptionV3Constants.INPUT_SHAPE) - processed_images = preprocess_input(resized_images) + # keras expects array in RGB order, we get it from image schema in BGR => need to flip + processed_images = preprocess_input(imageIO._reverseChannels(resized_images)) self.assertEqual(processed_images.shape[1], InceptionV3Constants.INPUT_SHAPE[0]) self.assertEqual(processed_images.shape[2], InceptionV3Constants.INPUT_SHAPE[1]) - transformer = TFImageTransformer(inputCol="image", outputCol=outputCol, graph=g, + transformer = TFImageTransformer(channelOrder='BGR', inputCol="image", outputCol=outputCol, graph=g, inputTensor=image_arr.name, outputTensor=processed_images, outputMode=outputMode) image_df = image_utils.getSampleImageDF() @@ -99,11 +124,10 @@ def test_image_output(self): # TODO: add tests for non-RGB8 images, at least RGB-float32. - # Test InceptionV3 prediction as an example of applying a trained model. 
def _executeTensorflow(self, graph, input_tensor_name, output_tensor_name, - df, id_col="filePath", input_col="image"): + df, input_col="image"): with tf.Session(graph=graph) as sess: output_tensor = graph.get_tensor_by_name(output_tensor_name) image_collected = df.collect() @@ -111,11 +135,11 @@ def _executeTensorflow(self, graph, input_tensor_name, output_tensor_name, topK = {} for img_row in image_collected: image = np.expand_dims(imageStructToArray(img_row[input_col]), axis=0) - uri = img_row[id_col] + uri = img_row['image']['origin'] output = sess.run([output_tensor], feed_dict={ graph.get_tensor_by_name(input_tensor_name): image - }) + }) values[uri] = np.array(output[0]) topK[uri] = decode_predictions(values[uri], top=5)[0] return values, topK @@ -129,22 +153,22 @@ def test_prediction_vs_tensorflow_inceptionV3(self): with g.as_default(): K.set_learning_phase(0) # this is important but it's on the user to call it. # nChannels needed for input_tensor in the InceptionV3 call below - image_string = utils.imageInputPlaceholder(nChannels = 3) + image_string = utils.imageInputPlaceholder(nChannels=3) resized_images = tf.image.resize_images(image_string, InceptionV3Constants.INPUT_SHAPE) - preprocessed = preprocess_input(resized_images) + # keras expects array in RGB order, we get it from image schema in BGR => need to flip + preprocessed = preprocess_input(imageIO._reverseChannels(resized_images)) model = InceptionV3(input_tensor=preprocessed, weights="imagenet") graph = tfx.strip_and_freeze_until([model.output], g, sess, return_graph=True) - transformer = TFImageTransformer(inputCol="image", outputCol=output_col, graph=graph, + transformer = TFImageTransformer(channelOrder='BGR', inputCol="image", outputCol=output_col, graph=graph, inputTensor=image_string, outputTensor=model.output, outputMode="vector") transformed_df = transformer.transform(image_df.limit(10)) self.assertDfHasCols(transformed_df, [output_col]) collected = transformed_df.collect() 
transformer_values, transformer_topK = self.transformOutputToComparables(collected, - "filePath", - output_col) + output_col, lambda row: row['image']['origin']) tf_values, tf_topK = self._executeTensorflow(graph, image_string.name, model.output.name, image_df) diff --git a/python/tests/transformers/tf_transformer_test.py b/python/tests/transformers/tf_transformer_test.py index 849a84d7..0f18af41 100644 --- a/python/tests/transformers/tf_transformer_test.py +++ b/python/tests/transformers/tf_transformer_test.py @@ -27,6 +27,7 @@ from ..tests import SparkDLTestCase + class TFTransformerTests(SparkDLTestCase): def test_graph_novar(self): transformer = _build_transformer(lambda session: @@ -79,6 +80,7 @@ def _build_graph(sess): x = tf.placeholder(tf.float64, shape=[None, _tensor_size], name=_tensor_input_name) _ = tf.reduce_max(x, axis=1, name=_tensor_output_name) + def _build_local_features(): """ Build numpy array (i.e. local) features. @@ -95,6 +97,7 @@ def _build_local_features(): return local_features + def _get_expected_result(gin, local_features): """ Running the graph in the :py:obj:`TFInputGraph` object and compute the expected results. 
@@ -123,6 +126,7 @@ def _get_expected_result(gin, local_features): return expected + def _check_transformer_output(transformer, dataset, expected): """ Given a transformer and a spark dataset, check if the transformer diff --git a/python/tests/udf/keras_sql_udf_test.py b/python/tests/udf/keras_sql_udf_test.py index 5c67c854..9743e940 100644 --- a/python/tests/udf/keras_sql_udf_test.py +++ b/python/tests/udf/keras_sql_udf_test.py @@ -27,21 +27,25 @@ from pyspark import SparkContext from pyspark.sql import DataFrame, Row from pyspark.sql.functions import udf +from sparkdl.image.image import ImageSchema from sparkdl.graph.builder import IsolatedSession from sparkdl.graph.tensorframes_udf import makeGraphUDF import sparkdl.graph.utils as tfx from sparkdl.udf.keras_image_model import registerKerasImageUDF from sparkdl.utils import jvmapi as JVMAPI -from sparkdl.image.imageIO import imageSchema, imageArrayToStruct +from sparkdl.image.imageIO import imageArrayToStruct +from sparkdl.image.imageIO import _reverseChannels from ..tests import SparkDLTestCase from ..transformers.image_utils import getSampleImagePathsDF + def get_image_paths_df(sqlCtx): df = getSampleImagePathsDF(sqlCtx, "fpath") df.createOrReplaceTempView("_test_image_paths_df") return df + class SqlUserDefinedFunctionTest(SparkDLTestCase): def _assert_function_exists(self, fh_name): @@ -55,7 +59,7 @@ def test_simple_keras_udf(self): # The leading batch size is taken care of by Keras with IsolatedSession(using_keras=True) as issn: model = Sequential() - model.add(Flatten(input_shape=(640,480,3))) + model.add(Flatten(input_shape=(640, 480, 3))) model.add(Dense(units=64)) model.add(Activation('relu')) model.add(Dense(units=10)) @@ -98,14 +102,18 @@ def pil_load_spimg(fpath): from PIL import Image import numpy as np img_arr = np.array(Image.open(fpath), dtype=np.uint8) - return imageArrayToStruct(img_arr) + # PIL is RGB, image schema is BGR => need to flip the channels + return 
imageArrayToStruct(_reverseChannels(img_arr)) def keras_load_spimg(fpath): - return imageArrayToStruct(keras_load_img(fpath)) + # Keras loads image in RGB order, ImageSchema expects BGR => need to flip + return imageArrayToStruct(_reverseChannels(keras_load_img(fpath))) # Load image with Keras and store it in our image schema - JVMAPI.registerUDF('keras_load_spimg', keras_load_spimg, imageSchema) - JVMAPI.registerUDF('pil_load_spimg', pil_load_spimg, imageSchema) + JVMAPI.registerUDF('keras_load_spimg', keras_load_spimg, + ImageSchema.imageSchema['image'].dataType) + JVMAPI.registerUDF('pil_load_spimg', pil_load_spimg, + ImageSchema.imageSchema['image'].dataType) # Register an InceptionV3 model registerKerasImageUDF("iv3_img_pred", @@ -150,7 +158,6 @@ def test_map_rows_sql_1(self): data2 = df2.collect() assert data2[0].z == 3.0, data2 - def test_map_blocks_sql_1(self): data = [Row(x=float(x)) for x in range(5)] df = self.sql.createDataFrame(data) diff --git a/python/tests/utils/test_python_interface.py b/python/tests/utils/test_python_interface.py index f3f411c1..906a5be2 100644 --- a/python/tests/utils/test_python_interface.py +++ b/python/tests/utils/test_python_interface.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import sys, traceback +import sys +import traceback from pyspark import SparkContext, SQLContext from pyspark.sql.column import Column from sparkdl.utils import jvmapi as JVMAPI from ..tests import SparkDLTestCase + class PythonAPITest(SparkDLTestCase): def test_using_api(self): diff --git a/src/main/scala/com/databricks/sparkdl/DeepImageFeaturizer.scala b/src/main/scala/com/databricks/sparkdl/DeepImageFeaturizer.scala index 3c032243..034463e9 100644 --- a/src/main/scala/com/databricks/sparkdl/DeepImageFeaturizer.scala +++ b/src/main/scala/com/databricks/sparkdl/DeepImageFeaturizer.scala @@ -18,7 +18,7 @@ package com.databricks.sparkdl import java.nio.file.Paths -import org.apache.spark.image.ImageSchema +import org.apache.spark.ml.image.ImageSchema import org.apache.spark.ml.Transformer import org.apache.spark.ml.linalg.SQLDataTypes.VectorType import org.apache.spark.ml.linalg.Vectors diff --git a/src/main/scala/com/databricks/sparkdl/ImageUtils.scala b/src/main/scala/com/databricks/sparkdl/ImageUtils.scala index b0362dcf..2e660ef6 100644 --- a/src/main/scala/com/databricks/sparkdl/ImageUtils.scala +++ b/src/main/scala/com/databricks/sparkdl/ImageUtils.scala @@ -19,8 +19,10 @@ package com.databricks.sparkdl import java.awt.image.BufferedImage import java.awt.{Color, Image} -import org.apache.spark.image.ImageSchema +import org.apache.spark.ml.image.ImageSchema import org.apache.spark.sql.Row +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.udf private[sparkdl] object ImageUtils { @@ -79,7 +81,7 @@ private[sparkdl] object ImageUtils { * @param image Java BufferedImage. * @return Row image in spark.ml.image format with 3 channels in BGR order. 
*/ - private[sparkdl] def spImageFromBufferedImage(image: BufferedImage): Row = { + private[sparkdl] def spImageFromBufferedImage(image: BufferedImage, origin: String = null): Row = { val channels = 3 val height = image.getHeight val width = image.getWidth @@ -98,8 +100,7 @@ private[sparkdl] object ImageUtils { } h += 1 } - // TODO: udpate mode to be Int when spark.ml.image is merged. - Row(null, height, width, channels, "CV_U8C3", decoded) + Row(origin, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), decoded) } /** @@ -137,8 +138,7 @@ private[sparkdl] object ImageUtils { val graphic = tgtImg.createGraphics() graphic.drawImage(scaledImg, 0, 0, null) graphic.dispose() - - spImageFromBufferedImage(tgtImg) + spImageFromBufferedImage(tgtImg, origin=ImageSchema.getOrigin(spImage)) } } } diff --git a/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala new file mode 100644 index 00000000..2d5c3327 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala @@ -0,0 +1,119 @@ +// NOTE: This file is copied from Spark 2.3 so that it can be used with already released Spark versions. +// TODO: remove this when Spark 2.3 is out! + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.image + +import scala.language.existentials +import scala.util.Random + +import org.apache.commons.io.FilenameUtils +import org.apache.hadoop.conf.{Configuration, Configured} +import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat + +import org.apache.spark.sql.SparkSession + +private object RecursiveFlag { + /** + * Sets the spark recursive flag and then restores it. + * + * @param value Value to set + * @param spark Existing spark session + * @param f The function to evaluate after setting the flag + * @return Returns the evaluation result T of the function + */ + def withRecursiveFlag[T](value: Boolean, spark: SparkSession)(f: => T): T = { + val flagName = FileInputFormat.INPUT_DIR_RECURSIVE + val hadoopConf = spark.sparkContext.hadoopConfiguration + val old = Option(hadoopConf.get(flagName)) + hadoopConf.set(flagName, value.toString) + try f finally { + old match { + case Some(v) => hadoopConf.set(flagName, v) + case None => hadoopConf.unset(flagName) + } + } + } +} + +/** + * Filter that allows loading a fraction of HDFS files. 
+ */ +private class SamplePathFilter extends Configured with PathFilter { + val random = new Random() + + // Ratio of files to be read from disk + var sampleRatio: Double = 1 + + override def setConf(conf: Configuration): Unit = { + if (conf != null) { + sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1) + val seed = conf.getLong(SamplePathFilter.seedParam, 0) + random.setSeed(seed) + } + } + + override def accept(path: Path): Boolean = { + // Note: checking fileSystem.isDirectory is very slow here, so we use basic rules instead + !SamplePathFilter.isFile(path) || random.nextDouble() < sampleRatio + } +} + +private object SamplePathFilter { + val ratioParam = "sampleRatio" + val seedParam = "seed" + + def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) != "" + + /** + * Sets the HDFS PathFilter flag and then restores it. + * Only applies the filter if sampleRatio is less than 1. + * + * @param sampleRatio Fraction of the files that the filter picks + * @param spark Existing Spark session + * @param seed Random number seed + * @param f The function to evaluate after setting the flag + * @return Returns the evaluation result T of the function + */ + def withPathFilter[T]( + sampleRatio: Double, + spark: SparkSession, + seed: Long)(f: => T): T = { + val sampleImages = sampleRatio < 1 + if (sampleImages) { + val flagName = FileInputFormat.PATHFILTER_CLASS + val hadoopConf = spark.sparkContext.hadoopConfiguration + val old = Option(hadoopConf.getClass(flagName, null)) + hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio) + hadoopConf.setLong(SamplePathFilter.seedParam, seed) + hadoopConf.setClass(flagName, classOf[SamplePathFilter], classOf[PathFilter]) + try f finally { + hadoopConf.unset(SamplePathFilter.ratioParam) + hadoopConf.unset(SamplePathFilter.seedParam) + old match { + case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter]) + case None => hadoopConf.unset(flagName) + } + } + } else { + f + } + } +} 
diff --git a/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala new file mode 100644 index 00000000..9ebce2ad --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -0,0 +1,261 @@ +package org.apache.spark.ml.image +// NOTE: This file is copied from Spark 2.3 so that it can be used with already released Spark versions. +// TODO: remove this when Spark 2.3 is out! + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import java.awt.Color +import java.awt.color.ColorSpace +import java.io.ByteArrayInputStream +import javax.imageio.ImageIO + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.input.PortableDataStream +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Defines the image schema and methods to read and manipulate images. 
+ */ +@Experimental +@Since("2.3.0") +object ImageSchema { + + val undefinedImageType = "Undefined" + + /** + * (Scala-specific) OpenCV type mapping supported + */ + val ocvTypes: Map[String, Int] = Map( + undefinedImageType -> -1, + "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24 + ) + + /** + * (Java-specific) OpenCV type mapping supported + */ + val javaOcvTypes: java.util.Map[String, Int] = ocvTypes.asJava + + /** + * Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte]) + */ + val columnSchema = StructType( + StructField("origin", StringType, true) :: + StructField("height", IntegerType, false) :: + StructField("width", IntegerType, false) :: + StructField("nChannels", IntegerType, false) :: + // OpenCV-compatible type: CV_8UC3 in most cases + StructField("mode", IntegerType, false) :: + // Bytes in OpenCV-compatible order: row-wise BGR in most cases + StructField("data", BinaryType, false) :: Nil) + + val imageFields: Array[String] = columnSchema.fieldNames + + /** + * DataFrame with a single column of images named "image" (nullable) + */ + val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil) + + /** + * Gets the origin of the image + * + * @return The origin of the image + */ + def getOrigin(row: Row): String = row.getString(0) + + /** + * Gets the height of the image + * + * @return The height of the image + */ + def getHeight(row: Row): Int = row.getInt(1) + + /** + * Gets the width of the image + * + * @return The width of the image + */ + def getWidth(row: Row): Int = row.getInt(2) + + /** + * Gets the number of channels in the image + * + * @return The number of channels in the image + */ + def getNChannels(row: Row): Int = row.getInt(3) + + /** + * Gets the OpenCV representation as an int + * + * @return The OpenCV representation as an int + */ + def getMode(row: Row): Int = row.getInt(4) + + /** + * Gets the image data + * + * @return The image data + */ + def getData(row: Row): Array[Byte] = 
row.getAs[Array[Byte]](5) + + /** + * Default values for the invalid image + * + * @param origin Origin of the invalid image + * @return Row with the default values + */ + private[spark] def invalidImageRow(origin: String): Row = + Row(Row(origin, -1, -1, -1, ocvTypes(undefinedImageType), Array.ofDim[Byte](0))) + + /** + * Convert the compressed image (jpeg, png, etc.) into OpenCV + * representation and store it in DataFrame Row + * + * @param origin Arbitrary string that identifies the image + * @param bytes Image bytes (for example, jpeg) + * @return DataFrame Row or None (if the decompression fails) + */ + private[spark] def decode(origin: String, bytes: Array[Byte]): Option[Row] = { + + val img = ImageIO.read(new ByteArrayInputStream(bytes)) + + if (img == null) { + None + } else { + val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY + val hasAlpha = img.getColorModel.hasAlpha + + val height = img.getHeight + val width = img.getWidth + val (nChannels, mode) = if (isGray) { + (1, ocvTypes("CV_8UC1")) + } else if (hasAlpha) { + (4, ocvTypes("CV_8UC4")) + } else { + (3, ocvTypes("CV_8UC3")) + } + + val imageSize = height * width * nChannels + assert(imageSize < 1e9, "image is too large") + val decoded = Array.ofDim[Byte](imageSize) + + // Grayscale images in Java require special handling to get the correct intensity + if (isGray) { + var offset = 0 + val raster = img.getRaster + for (h <- 0 until height) { + for (w <- 0 until width) { + decoded(offset) = raster.getSample(w, h, 0).toByte + offset += 1 + } + } + } else { + var offset = 0 + for (h <- 0 until height) { + for (w <- 0 until width) { + val color = new Color(img.getRGB(w, h)) + + decoded(offset) = color.getBlue.toByte + decoded(offset + 1) = color.getGreen.toByte + decoded(offset + 2) = color.getRed.toByte + if (nChannels == 4) { + decoded(offset + 3) = color.getAlpha.toByte + } + offset += nChannels + } + } + } + + // the internal "Row" is needed, because the image is a single 
DataFrame column + Some(Row(Row(origin, height, width, nChannels, mode, decoded))) + } + } + + /** + * Read the directory of images from the local or remote source + * + * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag, + * there may be a race condition where one job overwrites the hadoop configs of another. + * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + * potentially non-deterministic. + * + * @param path Path to the image directory + * @return DataFrame with a single column "image" of images; + * see ImageSchema for the details + */ + def readImages(path: String): DataFrame = readImages(path, null, false, -1, false, 1.0, 0) + + /** + * Read the directory of images from the local or remote source + * + * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag, + * there may be a race condition where one job overwrites the hadoop configs of another. + * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but + * potentially non-deterministic. 
+ * + * @param path Path to the image directory + * @param sparkSession Spark Session, if omitted gets or creates the session + * @param recursive Recursive path search flag + * @param numPartitions Number of the DataFrame partitions, + * if omitted uses defaultParallelism instead + * @param dropImageFailures Drop the files that are not valid images from the result + * @param sampleRatio Fraction of the files loaded + * @return DataFrame with a single column "image" of images; + * see ImageSchema for the details + */ + def readImages( + path: String, + sparkSession: SparkSession, + recursive: Boolean, + numPartitions: Int, + dropImageFailures: Boolean, + sampleRatio: Double, + seed: Long): DataFrame = { + require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be between 0 and 1") + + val session = if (sparkSession != null) sparkSession else SparkSession.builder().getOrCreate + val partitions = + if (numPartitions > 0) { + numPartitions + } else { + session.sparkContext.defaultParallelism + } + + RecursiveFlag.withRecursiveFlag(recursive, session) { + SamplePathFilter.withPathFilter(sampleRatio, session, seed) { + val binResult = session.sparkContext.binaryFiles(path, partitions) + val streams = if (numPartitions == -1) binResult else binResult.repartition(partitions) + val convert = (origin: String, bytes: PortableDataStream) => + decode(origin, bytes.toArray()) + val images = if (dropImageFailures) { + streams.flatMap { case (origin, bytes) => convert(origin, bytes) } + } else { + streams.map { case (origin, bytes) => + convert(origin, bytes).getOrElse(invalidImageRow(origin)) + } + } + session.createDataFrame(images, imageSchema) + } + } + } +} + diff --git a/src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala b/src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala index 31f5a9b1..6182723b 100644 --- a/src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala +++ 
b/src/test/scala/com/databricks/sparkdl/DeepImageFeaturizerSuite.scala @@ -16,12 +16,14 @@ package com.databricks.sparkdl -import org.apache.spark.image.ImageSchema -import org.apache.spark.sql.functions.{col, lit} +import org.scalatest.FunSuite + +import org.apache.spark.ml.image.ImageSchema import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.{StructField, StructType} -import org.scalatest.FunSuite + class DeepImageFeaturizerSuite extends FunSuite with TestSparkContext with DefaultReadWriteTest { diff --git a/src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala b/src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala index ccdc82e9..8ec69522 100644 --- a/src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala +++ b/src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala @@ -21,8 +21,10 @@ import java.io.File import javax.imageio.ImageIO import scala.util.Random -import org.apache.spark.image.ImageSchema + +import org.apache.spark.ml.image.ImageSchema import org.apache.spark.sql.Row + import org.scalatest.FunSuite object ImageUtilsSuite { @@ -44,6 +46,7 @@ object ImageUtilsSuite { class ImageUtilsSuite extends FunSuite { // We want to make sure to test ImageUtils in headless mode to ensure it'll work on all systems. 
assert(System.getProperty("java.awt.headless") === "true") + import ImageUtilsSuite._ test("Test spImage resize.") { @@ -63,7 +66,7 @@ class ImageUtilsSuite extends FunSuite { val rand = new Random(971) val imageData = Array.ofDim[Byte](height * width * channels) rand.nextBytes(imageData) - val spImage = Row(null, height, width, channels, "CV_U8C3", imageData) + val spImage = Row(null, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), imageData) val bufferedImage = ImageUtils.spImageToBufferedImage(spImage) val testImage = ImageUtils.spImageFromBufferedImage(bufferedImage) assert(spImage === testImage, "Image changed during conversion.") @@ -81,7 +84,7 @@ class ImageUtilsSuite extends FunSuite { (0 until width).flatMap { j => Seq(x + j + 1, x + j + 4, x + j + 7) } }.map(_.toByte).toArray - val spImage = Row(null, height, width, 3, "CV_U8C3", rawData) + val spImage = Row(null, height, width, 3, ImageSchema.ocvTypes("CV_8UC3"), rawData) val bufferedImage = ImageUtils.spImageToBufferedImage(spImage) for (h <- 0 until height) { diff --git a/src/test/scala/org/apache/spark/sql/sparkdl_stubs/SparkDLStubsSuite.scala b/src/test/scala/org/apache/spark/sql/sparkdl_stubs/SparkDLStubsSuite.scala index 86464876..42ac2ed2 100644 --- a/src/test/scala/org/apache/spark/sql/sparkdl_stubs/SparkDLStubsSuite.scala +++ b/src/test/scala/org/apache/spark/sql/sparkdl_stubs/SparkDLStubsSuite.scala @@ -28,7 +28,7 @@ import com.databricks.sparkdl.TestSparkContext class SparkDLStubSuite extends FunSuite with TestSparkContext { test("Registered UDF must be found") { - val udfName = "sparkdl-test-udf" + val udfName = "sparkdl_test_udf" val udfImpl = { (x: Int, y: Int) => x + y } UDFUtils.registerUDF(spark.sqlContext, udfName, udf(udfImpl)) assert(spark.catalog.functionExists(udfName)) diff --git a/src/test/scala/org/tensorframes/impl/SqlOpsSuite.scala b/src/test/scala/org/tensorframes/impl/SqlOpsSuite.scala index 6a9dad0e..c36a2922 100644 --- 
a/src/test/scala/org/tensorframes/impl/SqlOpsSuite.scala +++ b/src/test/scala/org/tensorframes/impl/SqlOpsSuite.scala @@ -53,7 +53,7 @@ class SqlOpsSpec extends FunSuite with TestSparkContext with GraphScoping with L Seq("p1", "p2"), Map("a" -> "a")) - val udfName = "tfs-test-simple-add" + val udfName = "tfs_test_simple_add" val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false) UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registeration assert(spark.catalog.functionExists(udfName))