core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala (7 additions, 0 deletions)
@@ -213,6 +213,12 @@ class HadoopRDD[K, V](

val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop)

// Sets the thread local variable for the file's name
split.inputSplit.value match {
  case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString)
  case _ => SqlNewHadoopRDD.unsetInputFileName()
}
Review comment (Contributor):

Can you call SqlNewHadoopRDD.unsetInputFileName() in https://github.com/apache/spark/pull/9542/files#diff-83eb37f7b0ebed3c14ccb7bff0d577c2R257? Like what we do in SqlNewHadoopRDD?


// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
@@ -250,6 +256,7 @@ class HadoopRDD[K, V](

override def close() {
if (reader != null) {
SqlNewHadoopRDD.unsetInputFileName()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
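
The call added in close() resets the same thread-local that the earlier hunk sets before a split is read, so a reused executor thread never reports a stale file name. As a rough illustration of the pattern SqlNewHadoopRDD.setInputFileName / unsetInputFileName rely on, here is a minimal Scala sketch of a thread-local holder; the object and method names below (InputFileNameHolder, getInputFileName) are illustrative assumptions, not the actual SqlNewHadoopRDD internals:

// Minimal sketch, not the real implementation: one file name per task thread.
object InputFileNameHolder {
  // Empty string means "no file is currently being read by this thread".
  private val inputFileName = new ThreadLocal[String] {
    override def initialValue(): String = ""
  }

  // Read on the task thread when input_file_name() is evaluated.
  def getInputFileName(): String = inputFileName.get()

  // Called by the RDD right before it starts reading a FileSplit.
  def setInputFileName(file: String): Unit = inputFileName.set(file)

  // Called from close() so a reused thread does not leak a stale name.
  def unsetInputFileName(): Unit = inputFileName.remove()
}
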
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.hive.execution

import java.io.{DataInput, DataOutput}
import java.io.{DataInput, DataOutput, File, PrintWriter}
import java.util.{ArrayList, Arrays, Properties}

import org.apache.hadoop.conf.Configuration
@@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
import org.apache.hadoop.io.Writable
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.util.Utils
@@ -44,7 +45,7 @@ case class ListStringCaseClass(l: Seq[String])
/**
* A test suite for Hive custom UDFs.
*/
class HiveUDFSuite extends QueryTest with TestHiveSingleton {
class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {

import hiveContext.{udf, sql}
import hiveContext.implicits._
@@ -348,6 +349,94 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton {

sqlContext.dropTempTable("testUDF")
}

test("SPARK-11522 select input_file_name from non-parquet table"){

withTempDir { tempDir =>

// EXTERNAL OpenCSVSerde table pointing to LOCATION

val file1 = new File(tempDir + "/data1")
val writer1 = new PrintWriter(file1)
writer1.write("1,2")
writer1.close()

val file2 = new File(tempDir + "/data2")
val writer2 = new PrintWriter(file2)
writer2.write("1,2")
writer2.close()

sql(
s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT)
ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'
WITH SERDEPROPERTIES (
\"separatorChar\" = \",\",
\"quoteChar\" = \"\\\"\",
\"escapeChar\" = \"\\\\\")
LOCATION '$tempDir'
""")

val answer1 =
sql("SELECT input_file_name() FROM csv_table").head().getString(0)
assert(answer1.contains("data1") || answer1.contains("data2"))

val count1 = sql("SELECT input_file_name() FROM csv_table").distinct().count()
assert(count1 == 2)
sql("DROP TABLE csv_table")

// EXTERNAL pointing to LOCATION

sql(
s"""CREATE EXTERNAL TABLE external_t5 (c1 int, c2 int)
ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
LOCATION '$tempDir'
""")

val answer2 =
sql("SELECT input_file_name() as file FROM external_t5").head().getString(0)
assert(answer1.contains("data1") || answer1.contains("data2"))

val count2 = sql("SELECT input_file_name() as file FROM external_t5").distinct().count
assert(count2 == 2)
sql("DROP TABLE external_t5")
}

withTempDir { tempDir =>

// External parquet pointing to LOCATION

val parquetLocation = tempDir + "/external_parquet"
sql("SELECT 1, 2").write.parquet(parquetLocation)

sql(
s"""CREATE EXTERNAL TABLE external_parquet(c1 int, c2 int)
STORED AS PARQUET
LOCATION '$parquetLocation'
""")

val answer3 =
sql("SELECT input_file_name() as file FROM external_parquet").head().getString(0)
assert(answer3.contains("external_parquet"))

val count3 = sql("SELECT input_file_name() as file FROM external_parquet").distinct().count
assert(count3 == 1)
sql("DROP TABLE external_parquet")
}

// Non-External parquet pointing to /tmp/...

sql("CREATE TABLE parquet_tmp(c1 int, c2 int) " +
" STORED AS parquet " +
" AS SELECT 1, 2")

val answer4 =
sql("SELECT input_file_name() as file FROM parquet_tmp").head().getString(0)
assert(answer4.contains("parquet_tmp"))

val count4 = sql("SELECT input_file_name() as file FROM parquet_tmp").distinct().count
assert(count4 == 1)
sql("DROP TABLE parquet_tmp")
}
}

class TestPair(x: Int, y: Int) extends Writable with Serializable {
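
Beyond the test suite above, here is a hedged usage sketch of what this change enables from user code; the hiveContext handle and the text-backed Hive table named logs are assumptions for illustration only. With the HadoopRDD change, input_file_name() returns the path of the file each row was read from for non-parquet tables as well, instead of an empty string.

// Illustrative only: `hiveContext` and the Hive table `logs` are assumed to exist.
val withFiles = hiveContext.sql("SELECT input_file_name() AS source_file, * FROM logs")
// Expect one distinct value per data file backing the table.
withFiles.select("source_file").distinct().show()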