File: PartitioningAwareFileIndex.scala
@@ -126,33 +126,15 @@ abstract class PartitioningAwareFileIndex(
val caseInsensitiveOptions = CaseInsensitiveMap(parameters)
val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
- val inferredPartitionSpec = PartitioningUtils.parsePartitions(
+ val caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis
+ PartitioningUtils.parsePartitions(
leafDirs,
typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
basePaths = basePaths,
+ userSpecifiedSchema = userSpecifiedSchema,
+ caseSensitive = caseSensitive,
timeZoneId = timeZoneId)
- userSpecifiedSchema match {
-   case Some(userProvidedSchema) if userProvidedSchema.nonEmpty =>
-     val userPartitionSchema =
-       combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec)
> Contributor: we can remove combineInferredAndUserSpecifiedPartitionSchema now.

-     // we need to cast into the data type that user specified.
-     def castPartitionValuesToUserSchema(row: InternalRow) = {
-       InternalRow((0 until row.numFields).map { i =>
-         val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType
-         Cast(
-           Literal.create(row.get(i, dt), dt),
-           userPartitionSchema.fields(i).dataType,
-           Option(timeZoneId)).eval()
-       }: _*)
-     }
-
-     PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part =>
-       part.copy(values = castPartitionValuesToUserSchema(part.values))
-     })
-   case _ =>
-     inferredPartitionSpec
- }
}

private def prunePartitions(
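For context on why the removed infer-then-cast path was the SPARK-26188 regression: Java's Double.parseDouble accepts a trailing "d", so a partition directory like a=4d was type-inferred as DoubleType, and the later cast back to the user's StringType produced "4.0" instead of "4d". A minimal standalone sketch of that failure mode, using Catalyst's Literal and Cast directly (the values are illustrative, not from the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.StringType

// Type inference sees "4d" and parses it as the double 4.0 ...
val inferred = Literal(java.lang.Double.parseDouble("4d"))
// ... so casting back to the user-specified string type yields "4.0", not "4d".
val recast = Cast(inferred, StringType, Option("UTC")).eval()
assert(recast.toString == "4.0")
```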
File: PartitioningUtils.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
- import org.apache.spark.sql.catalyst.util.DateTimeUtils
+ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

@@ -94,18 +94,34 @@ object PartitioningUtils {
paths: Seq[Path],
typeInference: Boolean,
basePaths: Set[Path],
+ userSpecifiedSchema: Option[StructType],
+ caseSensitive: Boolean,
timeZoneId: String): PartitionSpec = {
- parsePartitions(paths, typeInference, basePaths, DateTimeUtils.getTimeZone(timeZoneId))
+ parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema,
+   caseSensitive, DateTimeUtils.getTimeZone(timeZoneId))
}

private[datasources] def parsePartitions(
paths: Seq[Path],
typeInference: Boolean,
basePaths: Set[Path],
+ userSpecifiedSchema: Option[StructType],
+ caseSensitive: Boolean,
timeZone: TimeZone): PartitionSpec = {
+ val userSpecifiedDataTypes = if (userSpecifiedSchema.isDefined) {
> Contributor: can we build this at the caller side, out of PartitioningUtils? Then we only need one extra parameter.
> Member (author): Personally I prefer to keep the parameter simple and easy to understand, so that the logic of the caller (outside PartitioningUtils) looks cleaner.
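For illustration only, a hedged sketch of the Contributor's alternative: build the name-to-type map once at the call site, so parsePartitions would take a single extra parameter. This is not what the PR ultimately does, and the setup below is hypothetical:

```scala
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

// Hypothetical caller-side setup mirroring PartitioningAwareFileIndex's fields.
val userSpecifiedSchema: Option[StructType] =
  Some(StructType(Seq(StructField("a", StringType))))

// Build the lookup at the caller, then pass only this map down.
val userSpecifiedDataTypes: Map[String, DataType] = userSpecifiedSchema
  .map(_.fields.map(f => f.name -> f.dataType).toMap)
  .getOrElse(Map.empty[String, DataType])
```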

+   val nameToDataType = userSpecifiedSchema.get.fields.map(f => f.name -> f.dataType).toMap
+   if (caseSensitive) {
+     CaseInsensitiveMap(nameToDataType)
> Contributor: isn't this if (!caseSensitive)?
> Member (author): Yes, thanks for pointing it out :)

+   } else {
+     nameToDataType
+   }
+ } else {
+   Map.empty[String, DataType]
+ }

// First, we need to parse every partition's path and see if we can find partition values.
val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
-   parsePartition(path, typeInference, basePaths, timeZone)
+   parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, timeZone)
}.unzip

// We create pairs of (path -> path's partition value) here
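To make the caseSensitive exchange above concrete: the wrapper has to apply when analysis is case-insensitive, so lookups succeed whatever casing the directory names use. A small sketch under that corrected condition (the column name is invented):

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types.{DataType, StringType}

val caseSensitive = false
val nameToDataType = Map[String, DataType]("PartCol" -> StringType)
// Corrected guard: wrap only when the analysis is case-insensitive.
val lookup = if (!caseSensitive) CaseInsensitiveMap(nameToDataType) else nameToDataType
assert(lookup.contains("partcol")) // matches regardless of case
```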
@@ -147,14 +163,22 @@ object PartitioningUtils {
columnNames.zip(literals).map { case (name, Literal(_, dataType)) =>
// We always assume partition columns are nullable since we've no idea whether null values
// will be appended in the future.
-     StructField(name, dataType, nullable = true)
+     StructField(name, userSpecifiedDataTypes.getOrElse(name, dataType), nullable = true)
}
}

// Finally, we create `Partition`s based on paths and resolved partition values.
val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map {
-   case (PartitionValues(_, literals), (path, _)) =>
-     PartitionPath(InternalRow.fromSeq(literals.map(_.value)), path)
+   case (PartitionValues(columnNames, literals), (path, _)) =>
> Contributor: unnecessary change?

+     val values = columnNames.zip(literals).map {
+       case (name, literal) =>
+         if (userSpecifiedDataTypes.contains(name)) {
+           Cast(literal, userSpecifiedDataTypes(name), Option(timeZone.getID)).eval()
+         } else {
+           literal.value
+         }
+     }
+     PartitionPath(InternalRow.fromSeq(values), path)
}

PartitionSpec(StructType(fields), partitions)
@@ -185,6 +209,7 @@ object PartitioningUtils {
path: Path,
typeInference: Boolean,
basePaths: Set[Path],
+ userSpecifiedDataTypes: Map[String, DataType],
timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
val columns = ArrayBuffer.empty[(String, Literal)]
// Old Hadoop versions don't have `Path.isRoot`
@@ -206,7 +231,7 @@
// Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
// Once we get the string, we try to parse it and find the partition column and value.
val maybeColumn =
-   parsePartitionColumn(currentPath.getName, typeInference, timeZone)
+   parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, timeZone)
maybeColumn.foreach(columns += _)

// Now, we determine if we should stop.
@@ -239,6 +264,7 @@ object PartitioningUtils {
private def parsePartitionColumn(
columnSpec: String,
typeInference: Boolean,
+ userSpecifiedDataTypes: Map[String, DataType],
timeZone: TimeZone): Option[(String, Literal)] = {
val equalSignIndex = columnSpec.indexOf('=')
if (equalSignIndex == -1) {
@@ -250,7 +276,13 @@
val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")

- val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
+ val literal = if (userSpecifiedDataTypes.contains(columnName)) {
+   // SPARK-26188: if user provides corresponding column schema, process the column as String
+   // type and cast it as user specified data type later.
> Contributor: can we do the cast here? It's good practice to put related code together.
> Member (author):
>   1. the function returns Option[(String, Literal)]
>   2. the function inferPartitionColumnValue is quite complex; I don't want to change it or write duplicated logic.
+   inferPartitionColumnValue(rawColumnValue, false, timeZone)
> Contributor: can't we add the cast here?
> Member (author): See my reasons above.
> Contributor: can't we make it return Option[(String, Literal)]? If not, what about Literal(Cast(...).eval())?
> Member (author): Thanks. I will use the Cast one.

+ } else {
+   inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
+ }
Some(columnName -> literal)
}
}
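A sketch of the approach the author agreed to in the thread above: keep the raw value as a string literal (no inference), then wrap the cast result back into a Literal of the user-specified type. Illustrative only; the types and values below are assumptions, not the committed code:

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.IntegerType

val raw = Literal("10") // partition value parsed as a plain string
// Evaluate the cast eagerly and re-wrap it, per the Literal(Cast(...).eval()) suggestion.
val casted = Literal(Cast(raw, IntegerType, Option("UTC")).eval(), IntegerType)
assert(casted.value == 10)
```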
File: FileIndexSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
+ import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator}

class FileIndexSuite extends SharedSQLContext {
@@ -49,6 +50,21 @@ class FileIndexSuite extends SharedSQLContext {
}
}

test("SPARK-26188: don't infer data types of partition columns if user specifies schema") {
withTempDir { dir =>
val partitionDirectory = new File(dir, s"a=4d")
partitionDirectory.mkdir()
val file = new File(partitionDirectory, "text.txt")
stringToFile(file, "text")
val path = new Path(dir.getCanonicalPath)
val schema = StructType(Seq(StructField("a", StringType, false)))
val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema))
val partitionValues = fileIndex.partitionSpec().partitions.map(_.values)
assert(partitionValues.length == 1 && partitionValues(0).numFields == 1 &&
partitionValues(0).getString(0) == "4d")
}
}

test("InMemoryFileIndex: input paths are converted to qualified paths") {
withTempDir { dir =>
val file = new File(dir, "text.txt")
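For reference, a user-facing repro of the scenario the SPARK-26188 test above exercises; the path and schema here are hypothetical:

```scala
// Directory layout on disk: /tmp/data/a=4d/part-00000.csv
val df = spark.read
  .schema("a string, value string") // explicit schema: no inference for partition column "a"
  .csv("/tmp/data")
// Before the fix, "4d" was inferred as a double and came back as "4.0";
// with the fix, df.select("a") returns the literal "4d".
```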
File: ParquetPartitionDiscoverySuite.scala
@@ -101,7 +101,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
"hdfs://host:9000/path/a=10.5/b=hello")

var exception = intercept[AssertionError] {
- parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], timeZoneId)
+ parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))

@@ -115,6 +115,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/")),
+ None,
+ true,
timeZoneId)

// Valid
@@ -128,6 +130,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/something=true/table")),
+ None,
+ true,
timeZoneId)

// Valid
@@ -141,6 +145,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/table=true")),
+ None,
+ true,
timeZoneId)

// Invalid
@@ -154,6 +160,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/path/")),
+ None,
+ true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
@@ -174,20 +182,22 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
paths.map(new Path(_)),
true,
Set(new Path("hdfs://host:9000/tmp/tables/")),
+ None,
+ true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
}

test("parse partition") {
def check(path: String, expected: Option[PartitionValues]): Unit = {
- val actual = parsePartition(new Path(path), true, Set.empty[Path], timeZone)._1
+ val actual = parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)._1
assert(expected === actual)
}

def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
val message = intercept[T] {
- parsePartition(new Path(path), true, Set.empty[Path], timeZone)
+ parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)
}.getMessage

assert(message.contains(expected))
@@ -231,6 +241,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
path = new Path("file://path/a=10"),
typeInference = true,
basePaths = Set(new Path("file://path/a=10")),
+ Map.empty,
timeZone = timeZone)._1

assert(partitionSpec1.isEmpty)
@@ -240,6 +251,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
path = new Path("file://path/a=10"),
typeInference = true,
basePaths = Set(new Path("file://path")),
+ Map.empty,
timeZone = timeZone)._1

assert(partitionSpec2 ==
@@ -258,6 +270,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
paths.map(new Path(_)),
true,
rootPaths,
+ None,
+ true,
timeZoneId)
assert(actualSpec.partitionColumns === spec.partitionColumns)
assert(actualSpec.partitions.length === spec.partitions.length)
@@ -370,7 +384,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
test("parse partitions with type inference disabled") {
def check(paths: Seq[String], spec: PartitionSpec): Unit = {
val actualSpec =
- parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], timeZoneId)
+ parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None, true, timeZoneId)
assert(actualSpec === spec)
}
