
Commit 191bc0d

Changing to Seq for ArrayType, refactoring SQLParser for nested field extension
1 parent cbb5793 commit 191bc0d

6 files changed

Lines changed: 144 additions & 162 deletions
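Note: one practical effect of holding ArrayType values as Seq rather than Array (not stated in the commit message, but a plausible motivation) is that Scala compares Arrays by reference and Seqs element-wise, so rows containing array columns become directly comparable. A standalone illustration:

object SeqVsArray extends App {
  // Array equality is reference equality; Seq equality is structural.
  val a1 = Array(1, 2, 3)
  val a2 = Array(1, 2, 3)
  println(a1 == a2)              // false: two distinct Array instances
  println(a1.toSeq == a2.toSeq)  // true: compared element by element
}

The TODO comments added in the diffs below record the cost of this choice: a Seq[Int] boxes its elements, where an Array[Int] would not.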

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 56 additions & 56 deletions
@@ -66,43 +66,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected case class Keyword(str: String)
 
   protected implicit def asParser(k: Keyword): Parser[String] =
-    allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
-
-  protected class SqlLexical extends StdLexical {
-    case class FloatLit(chars: String) extends Token {
-      override def toString = chars
-    }
-    override lazy val token: Parser[Token] = (
-      identChar ~ rep( identChar | digit ) ^^
-        { case first ~ rest => processIdent(first :: rest mkString "") }
-      | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
-          case i ~ None => NumericLit(i mkString "")
-          case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
-        }
-      | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
-        { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
-      | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
-        { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
-      | EofCh ^^^ EOF
-      | '\'' ~> failure("unclosed string literal")
-      | '\"' ~> failure("unclosed string literal")
-      | delim
-      | failure("illegal character")
-    )
-
-    override def identChar = letter | elem('.') | elem('_') | elem('[') | elem(']')
-
-    override def whitespace: Parser[Any] = rep(
-      whitespaceChar
-      | '/' ~ '*' ~ comment
-      | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
-      | '#' ~ rep( chrExcept(EofCh, '\n') )
-      | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
-      | '/' ~ '*' ~ failure("unclosed comment")
-    )
-  }
-
-  override val lexical = new SqlLexical
+    lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
 
   protected val ALL = Keyword("ALL")
   protected val AND = Keyword("AND")
@@ -161,24 +125,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     this.getClass
       .getMethods
       .filter(_.getReturnType == classOf[Keyword])
-      .map(_.invoke(this).asInstanceOf[Keyword])
-
-  /** Generate all variations of upper and lower case of a given string */
-  private def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
-    if (s == "") {
-      Stream(prefix)
-    } else {
-      allCaseVersions(s.tail, prefix + s.head.toLower) ++
-        allCaseVersions(s.tail, prefix + s.head.toUpper)
-    }
-  }
+      .map(_.invoke(this).asInstanceOf[Keyword].str)
 
-  lexical.reserved ++= reservedWords.flatMap(w => allCaseVersions(w.str))
-
-  lexical.delimiters += (
-    "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
-    ",", ";", "%", "{", "}", ":", "[", "]"
-  )
+  override val lexical = new SqlLexical(reservedWords)
 
   protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
     exprs.zipWithIndex.map {
@@ -383,14 +332,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
 
   protected lazy val baseExpression: PackratParser[Expression] =
-    expression ~ "[" ~ expression <~ "]" ^^ {
+    expression ~ "[" ~ expression <~ "]" ^^ {
       case base ~ _ ~ ordinal => GetItem(base, ordinal)
     } |
     TRUE ^^^ Literal(true, BooleanType) |
     FALSE ^^^ Literal(false, BooleanType) |
     cast |
     "(" ~> expression <~ ")" |
-    "[" ~> literal <~ "]" |
     function |
     "-" ~> literal ^^ UnaryMinus |
     ident ^^ UnresolvedAttribute |
@@ -400,3 +348,55 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected lazy val dataType: Parser[DataType] =
     STRING ^^^ StringType
 }
+
+class SqlLexical(val keywords: Seq[String]) extends StdLexical {
+  case class FloatLit(chars: String) extends Token {
+    override def toString = chars
+  }
+
+  reserved ++= keywords.flatMap(w => allCaseVersions(w))
+
+  delimiters += (
+    "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+    ",", ";", "%", "{", "}", ":", "[", "]"
+  )
+
+  override lazy val token: Parser[Token] = (
+    identChar ~ rep( identChar | digit ) ^^
+      { case first ~ rest => processIdent(first :: rest mkString "") }
+    | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
+        case i ~ None => NumericLit(i mkString "")
+        case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
+      }
+    | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
+      { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
+    | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
+      { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
+    | EofCh ^^^ EOF
+    | '\'' ~> failure("unclosed string literal")
+    | '\"' ~> failure("unclosed string literal")
+    | delim
+    | failure("illegal character")
+  )
+
+  override def identChar = letter | elem('_') | elem('.')
+
+  override def whitespace: Parser[Any] = rep(
+    whitespaceChar
+    | '/' ~ '*' ~ comment
+    | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
+    | '#' ~ rep( chrExcept(EofCh, '\n') )
+    | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
+    | '/' ~ '*' ~ failure("unclosed comment")
+  )
+
+  /** Generate all variations of upper and lower case of a given string */
+  def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
+    if (s == "") {
+      Stream(prefix)
+    } else {
+      allCaseVersions(s.tail, prefix + s.head.toLower) ++
+        allCaseVersions(s.tail, prefix + s.head.toUpper)
+    }
+  }
+}
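Note: the refactoring moves the lexer out of SqlParser into a standalone SqlLexical class that receives its keyword list through the constructor, and allCaseVersions moves with it so reserved words can be registered case-insensitively at construction time. A minimal standalone sketch of what allCaseVersions computes (mirroring the method above, outside Spark):

object AllCaseVersionsDemo extends App {
  // Generate all upper/lower-case variants of a string, as in SqlLexical.
  def allCaseVersions(s: String, prefix: String = ""): Stream[String] =
    if (s == "") Stream(prefix)
    else allCaseVersions(s.tail, prefix + s.head.toLower) ++
      allCaseVersions(s.tail, prefix + s.head.toUpper)

  println(allCaseVersions("as").toList)  // List(as, aS, As, AS)
}

A keyword of length n expands to 2^n variants, which is fine for SQL keywords but would not scale to long strings.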

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 3 additions & 1 deletion
@@ -50,7 +50,9 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
       null
     } else {
       if (child.dataType.isInstanceOf[ArrayType]) {
-        val baseValue = value.asInstanceOf[Array[_]]
+        // TODO: consider using Array[_] for ArrayType child to avoid
+        // boxing of primitives
+        val baseValue = value.asInstanceOf[Seq[_]]
         val o = key.asInstanceOf[Int]
         if (o >= baseValue.size || o < 0) {
           null
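Note: with ArrayType values now arriving as Seq[_], GetItem's bounds check returns null for out-of-range ordinals rather than throwing. A hypothetical standalone helper with the same lookup semantics (not the Catalyst code itself):

// Mirrors GetItem's guard on a Seq: invalid ordinals yield null.
def getItem(value: Seq[Any], ordinal: Int): Any =
  if (ordinal >= value.size || ordinal < 0) null else value(ordinal)

// getItem(Seq(10, 20, 30), 1) == 20
// getItem(Seq(10, 20, 30), 5) == null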

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 9 additions & 67 deletions
@@ -58,53 +58,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
    * can contain ordinal expressions, such as `field[i][j][k]...`.
    */
   def resolve(name: String): Option[NamedExpression] = {
-    def expandFunc(expType: (Expression, DataType), field: String): (Expression, DataType) = {
-      val (exp, t) = expType
-      val ordinalRegExp = """(\[(\d+|\w+)\])""".r
-      val fieldName = if (ordinalRegExp.findFirstIn(field).isDefined) {
-        field.substring(0, field.indexOf("["))
-      } else {
-        field
-      }
-      t match {
-        case ArrayType(elementType) =>
-          val ordinals = ordinalRegExp
-            .findAllIn(field)
-            .matchData
-            .map(_.group(2))
-          (ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
-            GetItem(v1, Literal(v2.toInt))), elementType)
-        case MapType(keyType, valueType) =>
-          val ordinals = ordinalRegExp
-            .findAllIn(field)
-            .matchData
-            .map(_.group(2))
-          // TODO: we should recover the JVM type of keyType to match the
-          // actual type of the key?! should we restrict ourselves to NativeType?
-          (ordinals.foldLeft(exp)((v1: Expression, v2: String) =>
-            GetItem(v1, Literal(v2, keyType))), valueType)
-        case StructType(fields) =>
-          val structField = fields
-            .find(_.name == fieldName)
-          if (!structField.isDefined) {
-            throw new TreeNodeException(
-              this, s"Trying to resolve Attribute but field ${fieldName} is not defined")
-          }
-          structField.get.dataType match {
-            case ArrayType(elementType) =>
-              val ordinals = ordinalRegExp.findAllIn(field).matchData.map(_.group(2))
-              (ordinals.foldLeft(
-                GetField(exp, fieldName).asInstanceOf[Expression])((v1: Expression, v2: String) =>
-                GetItem(v1, Literal(v2.toInt))),
-                elementType)
-            case _ =>
-              (GetField(exp, fieldName), structField.get.dataType)
-          }
-        case _ =>
-          expType
-      }
-    }
-
+    // TODO: extend SqlParser to handle field expressions
     val parts = name.split("\\.")
     // Collect all attributes that are output by this nodes children where either the first part
     // matches the name or where the first part matches the scope and the second part matches the
@@ -124,33 +78,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
         remainingParts.head
       }
       if (option.name == relevantRemaining) (option, remainingParts.tail.toList) :: Nil else Nil*/
+      // If the first part of the desired name matches a qualifier for this possible match, drop it.
+      /* TODO: from rebase!
+      val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts
+      if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
+      */
     }
 
     options.distinct match {
-      case (a, Nil) :: Nil => {
-        a.dataType match {
-          case ArrayType(_) | MapType(_, _) =>
-            val expression = expandFunc((a: Expression, a.dataType), name)._1
-            Some(Alias(expression, name)())
-          case _ => Some(a)
-        }
-      } // One match, no nested fields, use it.
+      case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
       // One match, but we also need to extract the requested nested field.
      case (a, nestedFields) :: Nil =>
         a.dataType match {
           case StructType(fields) =>
-            // this is compatibility reasons with earlier code!
-            // TODO: why only nestedFields and not parts?
-            // check for absence of nested arrays so there are only fields
-            if ((parts(0) :: nestedFields).forall(!_.matches("\\w*\\[(\\d+|\\w+)\\]+"))) {
-              Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
-            } else {
-              val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
-              Some(Alias(expression, nestedFields.last)())
-            }
-          case _ =>
-            val expression = parts.foldLeft((a: Expression, a.dataType))(expandFunc)._1
-            Some(Alias(expression, nestedFields.last)())
+            Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
+          case _ => None // Don't know how to resolve these field references
         }
       case Nil => None // No matches.
       case ambiguousReferences =>
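Note: resolve now handles only dotted struct access by folding GetField over the remaining name parts; ordinal syntax such as field[i] is deferred to a future SqlParser extension (per the TODO above). A standalone sketch of that foldLeft, using simplified stand-ins for Catalyst's Expression and GetField rather than the real classes:

object ResolveSketch extends App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class GetField(child: Expr, fieldName: String) extends Expr

  val parts = "a.b.c".split("\\.").toList
  // Fold the nested field names over the base attribute, outermost last.
  val resolved = parts.tail.foldLeft(Attr(parts.head): Expr)(GetField)
  println(resolved)  // GetField(GetField(Attr(a),b),c)
}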

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala

Lines changed: 6 additions & 7 deletions
@@ -63,7 +63,8 @@ private[sql] object CatalystConverter {
   val MAP_VALUE_SCHEMA_NAME = "value"
   val MAP_SCHEMA_NAME = "map"
 
-  type ArrayScalaType[T] = Array[T]
+  // TODO: consider using Array[T] for arrays to avoid boxing of primitive types
+  type ArrayScalaType[T] = Seq[T]
   type StructScalaType[T] = Seq[T]
   type MapScalaType[K, V] = Map[K, V]
@@ -426,7 +427,7 @@ private[parquet] class CatalystArrayConverter(
   override def end(): Unit = {
     assert(parent != null)
     // here we need to make sure to use ArrayScalaType
-    parent.updateField(index, buffer.toArray)
+    parent.updateField(index, buffer.toArray.toSeq)
     clearBuffer()
   }
 }
@@ -451,8 +452,7 @@ private[parquet] class CatalystNativeArrayConverter(
 
   type NativeType = elementType.JvmType
 
-  private var buffer: CatalystConverter.ArrayScalaType[NativeType] =
-    elementType.classTag.newArray(capacity)
+  private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity)
 
   private var elements: Int = 0
 
@@ -526,15 +526,14 @@ private[parquet] class CatalystNativeArrayConverter(
     // here we need to make sure to use ArrayScalaType
     parent.updateField(
       index,
-      buffer.slice(0, elements))
+      buffer.slice(0, elements).toSeq)
     clearBuffer()
   }
 
   private def checkGrowBuffer(): Unit = {
     if (elements >= capacity) {
       val newCapacity = 2 * capacity
-      val tmp: CatalystConverter.ArrayScalaType[NativeType] =
-        elementType.classTag.newArray(newCapacity)
+      val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity)
       Array.copy(buffer, 0, tmp, 0, capacity)
       buffer = tmp
       capacity = newCapacity
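Note: CatalystNativeArrayConverter keeps accumulating into a primitive Array (doubling capacity in checkGrowBuffer) and only converts the filled slice to a Seq when handing the value to its parent, so boxing is paid once per array rather than per element. The growth pattern in isolation (a generic sketch with Int standing in for elementType.JvmType; not the converter itself):

class GrowableIntBuffer(initialCapacity: Int = 4) {
  private var buffer = new Array[Int](initialCapacity)
  private var elements = 0

  def append(v: Int): Unit = {
    if (elements >= buffer.length) {  // as in checkGrowBuffer(): double on overflow
      val tmp = new Array[Int](2 * buffer.length)
      Array.copy(buffer, 0, tmp, 0, buffer.length)
      buffer = tmp
    }
    buffer(elements) = v
    elements += 1
  }

  // As in end(): expose only the filled prefix, converted to an immutable Seq.
  def result: Seq[Int] = buffer.slice(0, elements).toSeq
}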

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala

Lines changed: 2 additions & 2 deletions
@@ -377,13 +377,13 @@ private[sql] object ParquetTestData {
     val map2 = r1.addGroup(2)
     val keyValue3 = map2.addGroup(0)
     // TODO: currently only string key type supported
-    keyValue3.add(0, "7")
+    keyValue3.add(0, "seven")
     val valueGroup1 = keyValue3.addGroup(1)
     valueGroup1.add(0, 42.toLong)
     valueGroup1.add(1, "the answer")
     val keyValue4 = map2.addGroup(0)
     // TODO: currently only string key type supported
-    keyValue4.add(0, "8")
+    keyValue4.add(0, "eight")
     val valueGroup2 = keyValue4.addGroup(1)
     valueGroup2.add(0, 49.toLong)

0 commit comments
