Skip to content

Commit 39edb0a

Browse files
committed
Fix and test.
1 parent 643728d commit 39edb0a

File tree

2 files changed

+83
-10
lines changed

2 files changed

+83
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -625,9 +625,9 @@ object WithFields {
625625
val name = namePartsRemaining.head
626626
if (namePartsRemaining.length == 1) {
627627
col.dataType match {
628-
case ArrayType(et: StructType, containsNull) =>
628+
case ArrayType(et, containsNull) =>
629629
val lv = NamedLambdaVariable("arg", et, containsNull)
630-
val function = WithFields(lv, name :: Nil, value :: Nil)
630+
val function = withFieldHelper(lv, name :: Nil, value)
631631
ArrayTransform(col, LambdaFunction(function, Seq(lv)))
632632

633633
case _: StructType =>
@@ -638,21 +638,21 @@ object WithFields {
638638
}
639639
} else {
640640
val newNamesRemaining = namePartsRemaining.tail
641-
val resolver = SQLConf.get.resolver
642-
643-
val newCol = ExtractValue(col, Literal(name), resolver)
644-
val newValue = withFieldHelper(
645-
col = newCol,
646-
namePartsRemaining = newNamesRemaining,
647-
value = value)
648641

649642
col.dataType match {
650643
case ArrayType(et, containsNull) =>
651644
val lv = NamedLambdaVariable("arg", et, containsNull)
652-
val function = withFieldHelper(lv, namePartsRemaining, newValue)
645+
val function = withFieldHelper(lv, namePartsRemaining, value)
653646
ArrayTransform(col, LambdaFunction(function, Seq(lv)))
654647

655648
case _: StructType =>
649+
val resolver = SQLConf.get.resolver
650+
val newCol = ExtractValue(col, Literal(name), resolver)
651+
val newValue = withFieldHelper(
652+
col = newCol,
653+
namePartsRemaining = newNamesRemaining,
654+
value = value)
655+
656656
WithFields(col, name :: Nil, newValue :: Nil)
657657

658658
case dt =>

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,4 +1802,77 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
18021802
StructField("b", IntegerType, nullable = false))),
18031803
containsNull = false), nullable = false))))
18041804
}
1805+
1806+
private lazy val arrayStructArrayLevel1: DataFrame = spark.createDataFrame(
1807+
sparkContext.parallelize(Row(Array(Row(Array(Row(1, null, 3)), null, 3))) :: Nil),
1808+
StructType(
1809+
Seq(StructField("a", ArrayType(
1810+
StructType(Seq(
1811+
StructField("a", arrayType, nullable = false),
1812+
StructField("b", IntegerType, nullable = true),
1813+
StructField("c", IntegerType, nullable = false))),
1814+
containsNull = false)))))
1815+
1816+
test("withField should add and replace field to struct of array of array") {
1817+
checkAnswerAndSchema(
1818+
arrayStructArrayLevel1.withColumn("a", $"a".withField("a.d", lit(2))),
1819+
Row(Seq(Row(Seq(Row(1, null, 3, 2)), null, 3))) :: Nil,
1820+
StructType(
1821+
Seq(StructField("a", ArrayType(
1822+
StructType(Seq(
1823+
StructField("a", ArrayType(
1824+
StructType(Seq(
1825+
StructField("a", IntegerType, nullable = false),
1826+
StructField("b", IntegerType, nullable = true),
1827+
StructField("c", IntegerType, nullable = false),
1828+
StructField("d", IntegerType, nullable = false))),
1829+
containsNull = true), nullable = false),
1830+
StructField("b", IntegerType, nullable = true),
1831+
StructField("c", IntegerType, nullable = false))),
1832+
containsNull = false)))))
1833+
1834+
checkAnswerAndSchema(
1835+
arrayStructArrayLevel1.withColumn("a", $"a.a".withField("d", lit(2))),
1836+
Row(Seq(Seq(Row(1, null, 3, 2)))) :: Nil,
1837+
StructType(
1838+
Seq(StructField("a", ArrayType(
1839+
ArrayType(
1840+
StructType(Seq(
1841+
StructField("a", IntegerType, nullable = false),
1842+
StructField("b", IntegerType, nullable = true),
1843+
StructField("c", IntegerType, nullable = false),
1844+
StructField("d", IntegerType, nullable = false))),
1845+
containsNull = true),
1846+
containsNull = false)))))
1847+
1848+
checkAnswerAndSchema(
1849+
arrayStructArrayLevel1.withColumn("a", $"a".withField("a.b", lit(2))),
1850+
Row(Seq(Row(Seq(Row(1, 2, 3)), null, 3))) :: Nil,
1851+
StructType(
1852+
Seq(StructField("a", ArrayType(
1853+
StructType(Seq(
1854+
StructField("a", ArrayType(
1855+
StructType(Seq(
1856+
StructField("a", IntegerType, nullable = false),
1857+
StructField("b", IntegerType, nullable = false),
1858+
StructField("c", IntegerType, nullable = false))),
1859+
containsNull = true), nullable = false),
1860+
StructField("b", IntegerType, nullable = true),
1861+
StructField("c", IntegerType, nullable = false))),
1862+
containsNull = false)))))
1863+
1864+
checkAnswerAndSchema(
1865+
arrayStructArrayLevel1.withColumn("a", $"a.a".withField("b", lit(2))),
1866+
Row(Seq(Seq(Row(1, 2, 3)))) :: Nil,
1867+
StructType(
1868+
Seq(StructField("a", ArrayType(
1869+
ArrayType(
1870+
StructType(Seq(
1871+
StructField("a", IntegerType, nullable = false),
1872+
StructField("b", IntegerType, nullable = false),
1873+
StructField("c", IntegerType, nullable = false))),
1874+
containsNull = true),
1875+
containsNull = false)))))
1876+
1877+
}
18051878
}

0 commit comments

Comments
 (0)