@@ -21,7 +21,7 @@ import java.sql.Timestamp

import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, GlobFilter, Path}
import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.sql.SparkSession
@@ -30,8 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan,
GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
import org.apache.spark.sql.sources.{And, DataSourceRegister, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.SerializableConfiguration
@@ -108,36 +107,19 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
if (pathGlobPattern.forall(new GlobFilter(_).accept(fsPath))) {
val fs = fsPath.getFileSystem(broadcastedHadoopConf.value.value)
val fileStatus = fs.getFileStatus(fsPath)
val length = fileStatus.getLen
val modificationTime = fileStatus.getModificationTime

if (filterFuncs.forall(_.apply(fileStatus))) {
def readContent: Array[Byte] = {
val stream = fs.open(fsPath)
val content = try {
try {
ByteStreams.toByteArray(stream)
} finally {
Closeables.close(stream, true)
}
}

val fullOutput = dataSchema.map { f =>
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
}
val requiredOutput = fullOutput.filter { a =>
requiredSchema.fieldNames.contains(a.name)
}

// TODO: Add column pruning
// currently it still read the file content even if content column is not required.
val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)

val internalRow = InternalRow(
UTF8String.fromString(path),
DateTimeUtils.fromMillis(modificationTime),
length,
content
)

Iterator(requiredColumns(internalRow))
if (filterFuncs.forall(_.apply(fileStatus))) {
val row = genPrunedRow(path, fileStatus, readContent, requiredSchema.fieldNames)
Iterator(row)
} else {
Iterator.empty
}
@@ -206,6 +188,23 @@ object BinaryFileFormat {
case _ => (_ => true)
}
}

private[binaryfile] def genPrunedRow(
path: String,
status: FileStatus,
readContent: => Array[Byte],
requiredFieldNames: Array[String]): InternalRow = {

val values = requiredFieldNames.map {
case PATH => UTF8String.fromString(path)
case LENGTH => status.getLen
case MODIFICATION_TIME => DateTimeUtils.fromMillis(status.getModificationTime)
case CONTENT => readContent

Contributor:
I don't see a strong reason to prune other columns that are inexpensive. Code is much simpler if we only prune content.
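
For illustration, a rough sketch of the content-only pruning the reviewer seems to have in mind, reusing names from the code this PR removes (length, modificationTime, requiredColumns); this is an interpretation, not code from the PR:

    // Keep the cheap metadata columns unconditionally and guard only the
    // expensive file read; the existing projection still drops unused columns.
    val needContent = requiredSchema.fieldNames.contains(CONTENT)
    val internalRow = InternalRow(
      UTF8String.fromString(path),
      DateTimeUtils.fromMillis(modificationTime),
      length,
      if (needContent) readContent else null
    )
    Iterator(requiredColumns(internalRow))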


Contributor Author:
But I think the current code is simpler.
The previous code had some parts that are hard to read:

          val fullOutput = dataSchema.map { f =>
            AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
          }
          val requiredOutput = fullOutput.filter { a =>
            requiredSchema.fieldNames.contains(a.name)
          }
          val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
          ...
          Iterator(requiredColumns(internalRow))

case name => throw new RuntimeException(s"Unexpected field name: ${name}")
}
InternalRow(values: _*)

Contributor:
Do we need to change inferSchema(), or should we still return the content field with null values? cc: @cloud-fan


Contributor Author:
What about adding a "keep invalid" option: when a file fails to read, fill the content column with null?
Right now, when a file fails to load, the whole data source read breaks.
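
A rough sketch of what such an option could look like; purely illustrative — the keepInvalid flag, its name, and its semantics are assumptions and not part of this PR:

    // Hypothetical: if a "keep invalid" option were added to BinaryFileSourceOptions,
    // the content read could swallow read errors instead of failing the whole scan.
    def readContentOrNull(keepInvalid: Boolean): Array[Byte] = {
      try {
        readContent
      } catch {
        case _: java.io.IOException if keepInvalid => null
      }
    }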

}

}

class BinaryFileSourceOptions(
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.binaryfile

import java.io.File
import java.nio.file.{Files, StandardOpenOption}
import java.nio.file.attribute.PosixFilePermission
import java.sql.Timestamp

import scala.collection.JavaConverters._
@@ -28,11 +29,14 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, GlobFilter, Path}
import org.mockito.Mockito.{mock, when}

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
@@ -44,6 +48,8 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest

private var fs: FileSystem = _

private var file1: File = _

private var file1Status: FileStatus = _

override def beforeAll(): Unit = {
@@ -58,7 +64,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
val year2015Dir = new File(testDir, "year=2015")
year2015Dir.mkdir()

val file1 = new File(year2014Dir, "data.txt")
file1 = new File(year2014Dir, "data.txt")
Files.write(
file1.toPath,
Seq("2014-test").asJava,
@@ -286,4 +292,45 @@ class BinaryFileFormatSuite extends QueryTest with SharedSQLContext with SQLTest
EqualTo(MODIFICATION_TIME, file1Status.getModificationTime)
), true)
}


test("genPrunedRow") {

Contributor:
Can we just test buildReader on one file?


Contributor Author:
Do we need to test (and if so, how) that when content is pruned, the file is actually not read?
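
(The test below gets at this by passing readContent1 as a by-name argument and asserting on the readContent1Called flag, so a pruned content column must leave the flag false.)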

val path1 = "test:/path/to/dir1"
val len1 = 123L
val time1 = 4567L
val content1 = "abcd".getBytes
val status1 = mock(classOf[FileStatus])
when(status1.getLen).thenReturn(len1)
when(status1.getModificationTime).thenReturn(time1)

var readContent1Called = false
def readContent1: Array[Byte] = {
readContent1Called = true
content1
}

def test(fieldNames: String*): Unit = {
readContent1Called = false

val row = genPrunedRow(path1, status1, readContent1, fieldNames.toArray)
val expectedRowVals = fieldNames.toArray.map {
case PATH => UTF8String.fromString(path1)
case LENGTH => len1
case MODIFICATION_TIME => DateTimeUtils.fromMillis(time1)
case CONTENT => content1
}
val expectedRow = InternalRow(expectedRowVals: _*)
assert(row === expectedRow)
assert(fieldNames.contains(CONTENT) === readContent1Called)
}

test("path", "length", "modificationTime", "content")
test("path", "length", "modificationTime")
test("path", "modificationTime", "content")
test("path", "length")
test("path", "content", "modificationTime", "length")
test("path")
test("length")
test("content")
}
}