diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
index f1a1ca6616a2..80abc0ac75ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
@@ -57,7 +57,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
     override def toString = chars
   }
 
-  reserved ++= keywords.flatMap(w => allCaseVersions(w))
+  reserved ++= keywords.flatMap(w => Stream(w.toLowerCase()))
 
   delimiters += (
     "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
@@ -66,7 +66,13 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
 
   override lazy val token: Parser[Token] =
     ( identChar ~ (identChar | digit).* ^^
-      { case first ~ rest => processIdent((first :: rest).mkString) }
+      { case first ~ rest =>
+        val rsIdent = processIdent((first :: rest).mkString.toLowerCase)
+        if (rsIdent.getClass.getCanonicalName.contains("StdTokens.Keyword"))
+          Keyword(rsIdent.chars.toLowerCase())
+        else
+          processIdent((first :: rest).mkString)
+      }
     | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
         case i ~ None => NumericLit(i.mkString)
         case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString)
@@ -95,15 +101,6 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
     | '/' ~ '*' ~ failure("unclosed comment")
     ).*
 
-  /** Generate all variations of upper and lower case of a given string */
-  def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
-    if (s.isEmpty) {
-      Stream(prefix)
-    } else {
-      allCaseVersions(s.tail, prefix + s.head.toLower) #:::
-        allCaseVersions(s.tail, prefix + s.head.toUpper)
-    }
-  }
 }
 
 /**
@@ -139,8 +136,7 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr
   protected val TABLE = Keyword("TABLE")
   protected val UNCACHE = Keyword("UNCACHE")
 
-  protected implicit def asParser(k: Keyword): Parser[String] =
-    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+  protected implicit def asParser(k: Keyword): Parser[String] = k.str.toLowerCase
 
   private val reservedWords: Seq[String] =
     this
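
For context, the removed allCaseVersions helper enumerates every upper/lower-case
spelling of a keyword, so its output is exponential in keyword length. A minimal
standalone sketch (illustrative only, not part of the patch):

    // One spelling per upper/lower choice at each position: 2^n for n letters.
    def allCaseVersions(s: String, prefix: String = ""): Stream[String] =
      if (s.isEmpty) Stream(prefix)
      else allCaseVersions(s.tail, prefix + s.head.toLower) #:::
        allCaseVersions(s.tail, prefix + s.head.toUpper)

    allCaseVersions("table").size  // 32 spellings for a 5-letter keyword
    // A 48-letter keyword yields 2^48 (about 2.8e14) spellings, so both
    // reserving them all in the lexer and folding them into one parser with
    // reduce(_ | _) in asParser blow up before any input is parsed.
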
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index f79d4ff444dc..cfc0fdf863f5 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -36,8 +36,7 @@ import org.apache.spark.sql.catalyst.types._
  * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
  */
 class SqlParser extends AbstractSparkSQLParser {
-  protected implicit def asParser(k: Keyword): Parser[String] =
-    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+  protected implicit def asParser(k: Keyword): Parser[String] = k.str.toLowerCase
 
   protected val ABS = Keyword("ABS")
   protected val ALL = Keyword("ALL")
Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + protected val SERDEPROPERTIES = Keyword("SERDEPROPERTIES") + protected val TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST = + Keyword("TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST") + + // Use reflection to find the reserved words defined in this class. + protected val reservedWords = this.getClass.getMethods.filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new LowerCaseSqlLexical(reservedWords) + + protected lazy val ddl: Parser[String] = createTable + + /** + * CREATE FOREIGN TEMPORARY TABLE avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + */ + protected lazy val createTable: Parser[String] = + CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ~ + (SERDEPROPERTIES~>serde) ~ (TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST ~> test).? ^^ { + case tableName ~ provider ~ opts ~ sd ~ tst => + "parse success" + } + + protected lazy val test: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val serde: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } + +} + + + + +class LowerCaseParser extends StandardTokenParsers with PackratParsers { + def apply(input: String): Option[String] = { + phrase(ddl)(new lexical.Scanner(input)) match { + case Success(r, x) => Some(r) + case x => + None + } + } + protected case class Keyword(str: String) + + protected implicit def asParser(k: Keyword): Parser[String] = k.str.toLowerCase + + protected val AS = Keyword("AS") + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + protected val SERDEPROPERTIES = Keyword("SERDEPROPERTIES") + protected val TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST = + Keyword("TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST") + + // Use reflection to find the reserved words defined in this class. + protected val reservedWords = + this.getClass.getMethods.filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new LowerCaseSqlLexical(reservedWords) + + protected lazy val ddl: Parser[String] = createTable + + /** + * CREATE FOREIGN TEMPORARY TABLE avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + */ + protected lazy val createTable: Parser[String] = + CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ~ + (SERDEPROPERTIES~>serde) ~ (TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST ~> test).? 
+class LowerCaseParser extends StandardTokenParsers with PackratParsers {
+
+  def apply(input: String): Option[String] = {
+    phrase(ddl)(new lexical.Scanner(input)) match {
+      case Success(r, x) => Some(r)
+      case x => None
+    }
+  }
+
+  protected case class Keyword(str: String)
+
+  protected implicit def asParser(k: Keyword): Parser[String] = k.str.toLowerCase
+
+  protected val AS = Keyword("AS")
+  protected val CREATE = Keyword("CREATE")
+  protected val TEMPORARY = Keyword("TEMPORARY")
+  protected val TABLE = Keyword("TABLE")
+  protected val USING = Keyword("USING")
+  protected val OPTIONS = Keyword("OPTIONS")
+  protected val SERDEPROPERTIES = Keyword("SERDEPROPERTIES")
+  protected val TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST =
+    Keyword("TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST")
+
+  // Use reflection to find the reserved words defined in this class.
+  protected val reservedWords =
+    this.getClass.getMethods.filter(_.getReturnType == classOf[Keyword])
+      .map(_.invoke(this).asInstanceOf[Keyword].str)
+
+  override val lexical = new LowerCaseSqlLexical(reservedWords)
+
+  protected lazy val ddl: Parser[String] = createTable
+
+  /**
+   * CREATE TEMPORARY TABLE avroTable
+   * USING org.apache.spark.sql.avro
+   * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
+   */
+  protected lazy val createTable: Parser[String] =
+    CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ~
+      (SERDEPROPERTIES ~> serde) ~ (TESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTESTTEST ~> test).? ^^ {
+      case tableName ~ provider ~ opts ~ sd ~ tst =>
+        "parse success"
+    }
+
+  protected lazy val test: Parser[Map[String, String]] =
+    "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
+
+  protected lazy val serde: Parser[Map[String, String]] =
+    "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
+
+  protected lazy val options: Parser[Map[String, String]] =
+    "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
+
+  protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".") }
+
+  protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k, v) }
+}
+
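+// Note: LowerCaseParser is deliberately identical to AllCaseVersionsParser
+// except for asParser, so the keyword-matching strategy is the only variable
+// between the two test cases.
+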
+/*
+ * Demonstrates the all-case-versions strategy: for a long keyword,
+ * allCaseVersions generates a huge Stream, and the reduce(_ | _) in the
+ * parser's asParser method then causes a StackOverflowError. This lexical is
+ * not instantiated by the tests above, since reserving every spelling of the
+ * 48-character keyword would itself be prohibitively expensive.
+ */
+class AllCaseVersionsSqlLexical(val keywords: Seq[String]) extends StdLexical {
+
+  case class FloatLit(chars: String) extends Token {
+    override def toString = chars
+  }
+
+  reserved ++= keywords.flatMap(w => allCaseVersions(w))
+
+  delimiters += (
+    "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+    ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~"
+  )
+
+  override lazy val token: Parser[Token] =
+    ( identChar ~ (identChar | digit).* ^^
+      { case first ~ rest => processIdent((first :: rest).mkString) }
+    | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
+        case i ~ None => NumericLit(i.mkString)
+        case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString)
+      }
+    | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
+      { case chars => StringLit(chars mkString "") }
+    | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
+      { case chars => StringLit(chars mkString "") }
+    | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^
+      { case chars => Identifier(chars mkString "") }
+    | EofCh ^^^ EOF
+    | '\'' ~> failure("unclosed string literal")
+    | '"' ~> failure("unclosed string literal")
+    | delim
+    | failure("illegal character")
+    )
+
+  override def identChar = letter | elem('_')
+
+  override def whitespace: Parser[Any] =
+    ( whitespaceChar
+    | '/' ~ '*' ~ comment
+    | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
+    | '#' ~ chrExcept(EofCh, '\n').*
+    | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
+    | '/' ~ '*' ~ failure("unclosed comment")
+    ).*
+
+  /** Generate all upper/lower-case variations of a given string. */
+  def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
+    if (s == "") {
+      Stream(prefix)
+    } else {
+      allCaseVersions(s.tail, prefix + s.head.toLower) ++
+        allCaseVersions(s.tail, prefix + s.head.toUpper)
+    }
+  }
+}
+
+/*
+ * Demonstrates the lower-case keyword-matching strategy, which avoids the
+ * StackOverflowError and speeds up keyword matching.
+ */
+class LowerCaseSqlLexical(val keywords: Seq[String]) extends StdLexical {
+
+  case class FloatLit(chars: String) extends Token {
+    override def toString = chars
+  }
+
+  reserved ++= keywords.flatMap(w => Stream(w.toLowerCase()))
+
+  delimiters += (
+    "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+    ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~"
+  )
+
+  override lazy val token: Parser[Token] =
+    ( identChar ~ (identChar | digit).* ^^
+      { case first ~ rest =>
+        val rsIdent = processIdent((first :: rest).mkString.toLowerCase())
+        if (rsIdent.getClass.getCanonicalName.contains("StdTokens.Keyword"))
+          Keyword(rsIdent.chars.toLowerCase())
+        else
+          processIdent((first :: rest).mkString)
+      }
+    | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
+        case i ~ None => NumericLit(i.mkString)
+        case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString)
+      }
+    | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
+      { case chars => StringLit(chars mkString "") }
+    | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
+      { case chars => StringLit(chars mkString "") }
+    | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^
+      { case chars => Identifier(chars mkString "") }
+    | EofCh ^^^ EOF
+    | '\'' ~> failure("unclosed string literal")
+    | '"' ~> failure("unclosed string literal")
+    | delim
+    | failure("illegal character")
+    )
+
+  override def identChar = letter | elem('_')
+
+  override def whitespace: Parser[Any] =
+    ( whitespaceChar
+    | '/' ~ '*' ~ comment
+    | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
+    | '#' ~ chrExcept(EofCh, '\n').*
+    | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
+    | '/' ~ '*' ~ failure("unclosed comment")
+    ).*
+
+  /**
+   * Generate all upper/lower-case variations of a given string. Kept here so
+   * that AllCaseVersionsParser above can still exercise the old strategy.
+   */
+  def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
+    if (s == "") {
+      Stream(prefix)
+    } else {
+      allCaseVersions(s.tail, prefix + s.head.toLower) ++
+        allCaseVersions(s.tail, prefix + s.head.toUpper)
+    }
+  }
+}
\ No newline at end of file
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 8a66ac31f2df..95752b920816 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -46,8 +46,7 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
 
   protected case class Keyword(str: String)
 
-  protected implicit def asParser(k: Keyword): Parser[String] =
-    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+  protected implicit def asParser(k: Keyword): Parser[String] = k.str.toLowerCase
 
   protected val CREATE = Keyword("CREATE")
   protected val TEMPORARY = Keyword("TEMPORARY")
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
index ebf7003ff9e5..6a913df366dc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -27,8 +27,7 @@ import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand}
 /**
  * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions.
  */
 private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
-  protected implicit def asParser(k: Keyword): Parser[String] =
-    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+  protected implicit def asParser(k: Keyword): Parser[String] = k.str.toLowerCase
 
   protected val ADD = Keyword("ADD")
   protected val DFS = Keyword("DFS")
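
Taken together, the patch moves case-insensitivity out of the parser and into
the lexer: each keyword is reserved once in lower case, candidate identifiers
are probed in lower case, and asParser shrinks to a single literal. A rough
sketch of the resulting behavior against the test lexer above (illustrative
only, not part of the patch):

    val lexical = new LowerCaseSqlLexical(Seq("CREATE", "TABLE"))
    // Tokenizing "creAtE hbase_people" yields
    //   Keyword("create"), Identifier("hbase_people")
    // asParser(CREATE) is now just the literal parser "create", which matches
    // the keyword token regardless of the input's case, while non-keyword
    // identifiers keep their original spelling.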