
Commit 552eba4

fix python
1 parent 5b5786d commit 552eba4

6 files changed: 25 additions & 12 deletions

python/pyspark/sql/context.py

Lines changed: 5 additions & 4 deletions
@@ -86,7 +86,8 @@ def __init__(self, sparkContext, sqlContext=None):
         >>> df.registerTempTable("allTypes")
         >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
         ...     'from allTypes where b and i > 0').collect()
-        [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+        [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
+            time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
         >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
         [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
         """
@@ -176,17 +177,17 @@ def registerFunction(self, name, f, returnType=StringType()):

         >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
-        [Row(c0=u'4')]
+        [Row(_c0=u'4')]

         >>> from pyspark.sql.types import IntegerType
         >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(c0=4)]
+        [Row(_c0=4)]

         >>> from pyspark.sql.types import IntegerType
         >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(c0=4)]
+        [Row(_c0=4)]
         """
         func = lambda _, it: map(lambda x: f(*x), it)
         ser = AutoBatchedSerializer(PickleSerializer())
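
These doctest updates track an observable behavior change: after this commit, select-list expressions without an explicit alias resolve to analyzer-generated names of the form _c0, _c1, ... (previously c0, c1, ...), while named attributes such as time and row.a keep their own names. A minimal, self-contained sketch of that naming rule (toy code, not Spark internals; nameColumns is a hypothetical helper):

    object DefaultNamesSketch {
      // Model each projection as Some(explicitName) or None for a bare expression.
      def nameColumns(cols: Seq[Option[String]]): Seq[String] =
        cols.zipWithIndex.map {
          case (Some(name), _) => name      // an explicit alias or attribute name wins
          case (None, i)       => s"_c$i"   // otherwise use the positional default
        }

      def main(args: Array[String]): Unit = {
        // `select i+1, d+1, time` with no aliases on the first two expressions
        println(nameColumns(Seq(None, None, Some("time"))))  // List(_c0, _c1, time)
      }
    }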

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
@@ -68,11 +68,11 @@ class Analyzer(
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
       ResolveReferences ::
-      ResolveAliases ::
       ResolveGroupingAnalytics ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
+      ResolveAliases ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
       UnresolvedHavingClauseAttributes ::
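
The reordering matters because a fixed-point batch applies its rules in the listed order on every pass: ResolveAliases can only pick a sensible name for an expression once ResolveFunctions has resolved it, so it now runs later in the batch. A toy fixed-point runner illustrating that in-batch ordering (assumed shape; the real mechanism lives in Catalyst's RuleExecutor):

    object FixedPointSketch {
      type Plan = String
      type Rule = Plan => Plan

      // Apply the rules in order, repeatedly, until the plan stops changing.
      def runBatch(rules: Seq[Rule], plan: Plan, maxIters: Int = 100): Plan = {
        var current = plan
        var changed = true
        var iters = 0
        while (changed && iters < maxIters) {
          val next = rules.foldLeft(current)((p, rule) => rule(p))
          changed = next != current
          current = next
          iters += 1
        }
        current
      }

      def main(args: Array[String]): Unit = {
        // The alias rule refuses to name anything while an unresolved call remains,
        // so it must run after the function-resolution rule within each pass.
        val resolveFunctions: Rule = _.replace("udf(x)", "length(x)")
        val resolveAliases: Rule =
          p => if (p.contains("udf(")) p else p.replace("<unnamed>", "_c0")
        println(runBatch(Seq(resolveFunctions, resolveAliases), "select udf(x) as <unnamed>"))
        // select length(x) as _c0
      }
    }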

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 10 additions & 2 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, ResolvedStar, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}

@@ -629,7 +629,15 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def select(cols: Column*): DataFrame = {
     val namedExpressions = cols.map {
-      case Column(expr: Expression) => UnresolvedAlias(expr)
+      // Wrap UnresolvedAttribute with UnresolvedAlias: when an UnresolvedAttribute is resolved,
+      // the intermediate Alias for the ExtractValue chain is removed, so it must be aliased
+      // again to make it a NamedExpression.
+      case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
+      case Column(expr: NamedExpression) => expr
+      // Leave an unaliased Explode with an empty list of names, since the analyzer will generate
+      // the correct defaults after the nested expression's type has been resolved.
+      case Column(explode: Explode) => MultiAlias(explode, Nil)
+      case Column(expr: Expression) => Alias(expr, expr.prettyString)()
     }
     // When a user calls `select` repeatedly, speed up analysis by collapsing `Project`
     import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
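
The new select no longer wraps every expression in UnresolvedAlias; it normalizes each column case by case. A self-contained sketch of that pattern-match shape over a toy expression type, covering the attribute, named, and fallback cases above (the Explode/MultiAlias case is omitted; names like UnresolvedAttr are stand-ins, not Catalyst classes):

    sealed trait Expr { def pretty: String }
    case class UnresolvedAttr(path: String) extends Expr { def pretty: String = path }
    case class AliasExpr(child: Expr, name: String) extends Expr {
      def pretty: String = s"${child.pretty} AS $name"
    }
    case class UnresolvedAliasExpr(child: Expr) extends Expr {
      def pretty: String = child.pretty  // name chosen later, after resolution
    }
    case class Add(left: Expr, right: Expr) extends Expr {
      def pretty: String = s"${left.pretty} + ${right.pretty}"
    }

    object SelectSketch {
      // Defer naming for bare attributes, keep explicit aliases, and eagerly
      // name everything else from its printed form.
      def named(cols: Seq[Expr]): Seq[Expr] = cols.map {
        case u: UnresolvedAttr => UnresolvedAliasExpr(u)
        case a: AliasExpr      => a
        case e                 => AliasExpr(e, e.pretty)
      }

      def main(args: Array[String]): Unit =
        named(Seq(UnresolvedAttr("row.a"), Add(UnresolvedAttr("i"), UnresolvedAttr("i"))))
          .foreach(println)
      // UnresolvedAliasExpr(UnresolvedAttr(row.a))
      // AliasExpr(Add(UnresolvedAttr(i),UnresolvedAttr(i)),i + i)
    }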

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 5 additions & 1 deletion
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
 import scala.language.implicitConversions

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
 import org.apache.spark.sql.types.NumericType

@@ -78,6 +78,10 @@ class GroupedData protected[sql](
     }

     val aliasedAgg = aggregates.map {
+      // Wrap UnresolvedAttribute with UnresolvedAlias: when an UnresolvedAttribute is resolved,
+      // the intermediate Alias for the ExtractValue chain is removed, so it must be aliased
+      // again to make it a NamedExpression.
+      case u: UnresolvedAttribute => UnresolvedAlias(u)
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
     }
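
GroupedData.agg needs the same wrapping as select, and the comment states why: resolving a dotted attribute such as s.a strips the intermediate alias that the ExtractValue chain carried, leaving a nameless expression that must be re-aliased. A toy illustration of that re-aliasing step (plain Scala, assumed structure, not Catalyst):

    object ReAliasSketch {
      // After resolution, an extraction chain like "s.a" is just nested field
      // accesses with no name of its own.
      case class ResolvedExtraction(repr: String)

      def resolve(path: String): ResolvedExtraction =
        ResolvedExtraction(path.split('.').mkString("GetField(", ", ", ")"))

      // Re-attach a name so downstream operators see a named column again.
      def reAlias(path: String): (String, ResolvedExtraction) =
        (path.split('.').last, resolve(path))

      def main(args: Array[String]): Unit =
        println(reAlias("s.a"))  // (a,ResolvedExtraction(GetField(s, a)))
    }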

sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
     // Skip EvaluatePython nodes.
     case plan: EvaluatePython => plan

-    case plan: LogicalPlan =>
+    case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
       val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
       if (udfs.isEmpty) {
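
The added `if plan.resolved` guard keeps the Python-UDF extraction rule from firing on half-analyzed plans, where expression types and children may still be unresolved. A minimal sketch of the guard pattern (toy types, not Spark's):

    object GuardedRuleSketch {
      case class Plan(resolved: Boolean, pythonUdfs: List[String])

      // Only transform plans the analyzer has fully resolved; pass the rest
      // through unchanged so a later pass can pick them up.
      def extractPythonUdfs(plan: Plan): Plan =
        if (!plan.resolved) plan
        else plan.copy(pythonUdfs = Nil)  // pretend the UDFs were pulled into a separate node

      def main(args: Array[String]): Unit = {
        println(extractPythonUdfs(Plan(resolved = false, List("strlen"))))  // unchanged
        println(extractPythonUdfs(Plan(resolved = true, List("strlen"))))   // extracted
      }
    }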

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 3 additions & 3 deletions
@@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {

   test("SPARK-6145: special cases") {
     sqlContext.read.json(sqlContext.sparkContext.makeRDD(
-      """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
-    checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
-    checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
+      """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
+    checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
+    checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
   }

   test("SPARK-6898: complete support for special chars in column names") {
