@@ -76,12 +76,12 @@ private[libsvm] class LibSVMFileFormat
 
   override def toString: String = "LibSVM"
 
-  private def verifySchema(dataSchema: StructType): Unit = {
+  private def verifySchema(dataSchema: StructType, forWriting: Boolean): Unit = {
     if (
       dataSchema.size != 2 ||
         !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
         !dataSchema(1).dataType.sameType(new VectorUDT()) ||
-        !(dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
+        !(forWriting || dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
     ) {
       throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
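
Note: the NUM_FEATURES metadata entry is only attached to schemas produced by the libsvm reader; a DataFrame assembled by hand has a plain vector column without it, which is why the check is skipped when forWriting is true. A minimal sketch of that situation, assuming a running SparkSession named spark and purely illustrative column names:

    import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

    // Hand-built schema: a valid (label, features) pair, but the vector column
    // carries no "numFeatures" metadata.
    val schema = StructType(Seq(
      StructField("label", DoubleType, nullable = false),
      StructField("features", SQLDataTypes.VectorType, nullable = false)))
    val rows = Seq(Row(1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))))
    val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
    // df.schema("features").metadata has no LibSVMOptions.NUM_FEATURES entry, so the
    // pre-change verifySchema fails here even though writing this data out is valid;
    // with forWriting = true the metadata requirement no longer applies.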
@@ -119,7 +119,7 @@ private[libsvm] class LibSVMFileFormat
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    verifySchema(dataSchema)
+    verifySchema(dataSchema, true)
     new OutputWriterFactory {
       override def newInstance(
           path: String,
@@ -142,7 +142,7 @@ private[libsvm] class LibSVMFileFormat
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
-    verifySchema(dataSchema)
+    verifySchema(dataSchema, false)
     val numFeatures = dataSchema("features").metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt
     assert(numFeatures > 0)
 
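Note: on the read path the feature dimension is still required, since each line has to be parsed into a vector of known size; it is either inferred by an extra pass over the data or supplied through the libsvm source's numFeatures read option. A hedged usage sketch (the path and dimension are made up):

    // Supplying numFeatures avoids an extra pass over the data to infer the dimension.
    val df = spark.read.format("libsvm")
      .option("numFeatures", "100")
      .load("path/to/data.libsvm")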
@@ -19,13 +19,16 @@ package org.apache.spark.ml.source.libsvm
 
 import java.io.{File, IOException}
 import java.nio.charset.StandardCharsets
+import java.util.List
 
 import com.google.common.io.Files
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 
@@ -109,14 +112,15 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data and read it again") {
val df = spark.read.format("libsvm").load(path)
val tempDir2 = new File(tempDir, "read_write_test")

Contributor: I suggest making the temp dir name Identifiable.randomUID("read_write_test"), to avoid conflicts with other tests running in parallel.

Member: Use Utils.createTempDir.

Author (@ProtD, Aug 15, 2017): Utils.createTempDir seems to be a nicer way. The directory is automatically deleted when the VM shuts down, so I believe no manual cleanup (cf. the comment below) is needed.
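
For reference, a sketch of what the suggestion above would look like in the test, assuming org.apache.spark.util.Utils.createTempDir (which creates a unique directory and registers it for deletion on JVM shutdown); the name prefix is illustrative:

    // Unique per-test directory, cleaned up automatically on JVM shutdown.
    val tempDir2 = Utils.createTempDir(namePrefix = "read_write_test")
    val writePath = tempDir2.toURI.toString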

-    val writepath = tempDir2.toURI.toString
+    val writePath = tempDir2.toURI.toString

Contributor: Use tempDir2.getPath.

     // TODO: Remove requirement to coalesce by supporting multiple reads.
-    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
+    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)
 
-    val df2 = spark.read.format("libsvm").load(writepath)
+    val df2 = spark.read.format("libsvm").load(writePath)
     val row1 = df2.first()
     val v = row1.getAs[SparseVector](1)
     assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+    Utils.deleteRecursively(tempDir2)

Contributor: You can remove this cleanup, I think; the test framework will clean the temp dir automatically.

   }
 
   test("write libsvm data failed due to invalid schema") {
@@ -126,6 +130,29 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
test("write libsvm data from scratch and read it again") {
val rawData = new java.util.ArrayList[Row]()
rawData.add(Row(1.0, Vectors.sparse(3, Seq((0, 2.0), (1, 3.0)))))
rawData.add(Row(4.0, Vectors.sparse(3, Seq((0, 5.0), (2, 6.0)))))


Member: Subtle: it didn't like the whitespace on this line.

Author: Fixed.

+    val struct = StructType(
+      StructField("labelFoo", DoubleType, false) ::
+      StructField("featuresBar", VectorType, false) :: Nil
+    )
+    val df = spark.sqlContext.createDataFrame(rawData, struct)
+
+    val tempDir2 = new File(tempDir, "read_write_test_2")
+    val writePath = tempDir2.toURI.toString
+
+    df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writePath)
+
+    val df2 = spark.read.format("libsvm").load(writePath)
+    val row1 = df2.first()
+    val v = row1.getAs[SparseVector](1)
+    assert(v == Vectors.sparse(3, Seq((0, 2.0), (1, 3.0))))
+    Utils.deleteRecursively(tempDir2)
+  }
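
For orientation, with 1-based feature indices the two rows in this test should round-trip through text roughly like the following (an illustrative sketch of the written libsvm file, not taken from the PR):

    1.0 1:2.0 2:3.0
    4.0 1:5.0 3:6.0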

test("select features from libsvm relation") {
val df = spark.read.format("libsvm").load(path)
df.select("features").rdd.map { case Row(d: Vector) => d }.first