
Commit 93f3556

kiszk authored and cloud-fan committed
[SPARK-16213][SQL] Reduce runtime overhead of a program that creates a primitive array in DataFrame
## What changes were proposed in this pull request?

This PR reduces the runtime overhead of a program that creates a primitive array in DataFrame, using an approach similar to #15044. Without this PR, the generated code performs a boxing operation in each assignment from an InternalRow to a temporary `Object[]` array (lines 051 and 061 in the generated code below). If we know that the type of the array elements is primitive, we apply the following optimizations:

1. Eliminate the pair of an `isNullAt()` check and a null assignment
2. Allocate a primitive array instead of an `Object[]` (eliminating boxing operations)
3. Create the `UnsafeArrayData` with `UnsafeArrayWriter`, keeping the primitive array in a row format instead of performing the non-lightweight operations in the constructor of `GenericArrayData`

The PR applies the same optimizations to `CreateMap`.

Here are performance results of [DataFrame programs](https://github.com/kiszk/spark/blob/6bf54ec5e227689d69f6db991e9ecbc54e153d0a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala#L83-L112), which run up to 17.9x faster with this PR:

```
Without SPARK-16043
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Read a primitive array in DataFrame:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                           3805 / 4150          0.0      507308.9       1.0X
Double                                        3593 / 3852          0.0      479056.9       1.1X

With SPARK-16043

Read a primitive array in DataFrame:     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            213 /  271          0.0       28387.5       1.0X
Double                                         204 /  223          0.0       27250.9       1.0X
```

Note: #15780 is enabled for these measurements.

A motivating example:

```scala
val df = sparkContext.parallelize(Seq(0.0d, 1.0d), 1).toDF
df.selectExpr("Array(value + 1.1d, value + 2.2d)").show
```

Generated code without this PR:

```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow serializefromobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 012 */   private Object[] project_values;
/* 013 */   private UnsafeRow project_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter project_arrayWriter;
/* 017 */
/* 018 */   public GeneratedIterator(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     inputadapter_input = inputs[0];
/* 026 */     serializefromobject_result = new UnsafeRow(1);
/* 027 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 028 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 029 */     this.project_values = null;
/* 030 */     project_result = new UnsafeRow(1);
/* 031 */     this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 32);
/* 032 */     this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1);
/* 033 */     this.project_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 034 */
/* 035 */   }
/* 036 */
/* 037 */   protected void processNext() throws java.io.IOException {
/* 038 */     while (inputadapter_input.hasNext()) {
/* 039 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 040 */       double inputadapter_value = inputadapter_row.getDouble(0);
/* 041 */
/* 042 */       final boolean project_isNull = false;
/* 043 */       this.project_values = new Object[2];
/* 044 */       boolean project_isNull1 = false;
/* 045 */
/* 046 */       double project_value1 = -1.0;
/* 047 */       project_value1 = inputadapter_value + 1.1D;
/* 048 */       if (false) {
/* 049 */         project_values[0] = null;
/* 050 */       } else {
/* 051 */         project_values[0] = project_value1;
/* 052 */       }
/* 053 */
/* 054 */       boolean project_isNull4 = false;
/* 055 */
/* 056 */       double project_value4 = -1.0;
/* 057 */       project_value4 = inputadapter_value + 2.2D;
/* 058 */       if (false) {
/* 059 */         project_values[1] = null;
/* 060 */       } else {
/* 061 */         project_values[1] = project_value4;
/* 062 */       }
/* 063 */
/* 064 */       final ArrayData project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_values);
/* 065 */       this.project_values = null;
/* 066 */       project_holder.reset();
/* 067 */
/* 068 */       project_rowWriter.zeroOutNullBytes();
/* 069 */
/* 070 */       if (project_isNull) {
/* 071 */         project_rowWriter.setNullAt(0);
/* 072 */       } else {
/* 073 */         // Remember the current cursor so that we can calculate how many bytes are
/* 074 */         // written later.
/* 075 */         final int project_tmpCursor = project_holder.cursor;
/* 076 */
/* 077 */         if (project_value instanceof UnsafeArrayData) {
/* 078 */           final int project_sizeInBytes = ((UnsafeArrayData) project_value).getSizeInBytes();
/* 079 */           // grow the global buffer before writing data.
/* 080 */           project_holder.grow(project_sizeInBytes);
/* 081 */           ((UnsafeArrayData) project_value).writeToMemory(project_holder.buffer, project_holder.cursor);
/* 082 */           project_holder.cursor += project_sizeInBytes;
/* 083 */
/* 084 */         } else {
/* 085 */           final int project_numElements = project_value.numElements();
/* 086 */           project_arrayWriter.initialize(project_holder, project_numElements, 8);
/* 087 */
/* 088 */           for (int project_index = 0; project_index < project_numElements; project_index++) {
/* 089 */             if (project_value.isNullAt(project_index)) {
/* 090 */               project_arrayWriter.setNullDouble(project_index);
/* 091 */             } else {
/* 092 */               final double project_element = project_value.getDouble(project_index);
/* 093 */               project_arrayWriter.write(project_index, project_element);
/* 094 */             }
/* 095 */           }
/* 096 */         }
/* 097 */
/* 098 */         project_rowWriter.setOffsetAndSize(0, project_tmpCursor, project_holder.cursor - project_tmpCursor);
/* 099 */       }
/* 100 */       project_result.setTotalSize(project_holder.totalSize());
/* 101 */       append(project_result);
/* 102 */       if (shouldStop()) return;
/* 103 */     }
/* 104 */   }
/* 105 */ }
```

Generated code with this PR:

```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow serializefromobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 012 */   private UnsafeArrayData project_arrayData;
/* 013 */   private UnsafeRow project_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter project_arrayWriter;
/* 017 */
/* 018 */   public GeneratedIterator(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */     inputadapter_input = inputs[0];
/* 026 */     serializefromobject_result = new UnsafeRow(1);
/* 027 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 028 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 029 */
/* 030 */     project_result = new UnsafeRow(1);
/* 031 */     this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 32);
/* 032 */     this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1);
/* 033 */     this.project_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 034 */
/* 035 */   }
/* 036 */
/* 037 */   protected void processNext() throws java.io.IOException {
/* 038 */     while (inputadapter_input.hasNext()) {
/* 039 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 040 */       double inputadapter_value = inputadapter_row.getDouble(0);
/* 041 */
/* 042 */       byte[] project_array = new byte[32];
/* 043 */       project_arrayData = new UnsafeArrayData();
/* 044 */       Platform.putLong(project_array, 16, 2);
/* 045 */       project_arrayData.pointTo(project_array, 16, 32);
/* 046 */
/* 047 */       boolean project_isNull1 = false;
/* 048 */
/* 049 */       double project_value1 = -1.0;
/* 050 */       project_value1 = inputadapter_value + 1.1D;
/* 051 */       if (false) {
/* 052 */         project_arrayData.setNullAt(0);
/* 053 */       } else {
/* 054 */         project_arrayData.setDouble(0, project_value1);
/* 055 */       }
/* 056 */
/* 057 */       boolean project_isNull4 = false;
/* 058 */
/* 059 */       double project_value4 = -1.0;
/* 060 */       project_value4 = inputadapter_value + 2.2D;
/* 061 */       if (false) {
/* 062 */         project_arrayData.setNullAt(1);
/* 063 */       } else {
/* 064 */         project_arrayData.setDouble(1, project_value4);
/* 065 */       }
/* 066 */       project_holder.reset();
/* 067 */
/* 068 */       // Remember the current cursor so that we can calculate how many bytes are
/* 069 */       // written later.
/* 070 */       final int project_tmpCursor = project_holder.cursor;
/* 071 */
/* 072 */       if (project_arrayData instanceof UnsafeArrayData) {
/* 073 */         final int project_sizeInBytes = ((UnsafeArrayData) project_arrayData).getSizeInBytes();
/* 074 */         // grow the global buffer before writing data.
/* 075 */         project_holder.grow(project_sizeInBytes);
/* 076 */         ((UnsafeArrayData) project_arrayData).writeToMemory(project_holder.buffer, project_holder.cursor);
/* 077 */         project_holder.cursor += project_sizeInBytes;
/* 078 */
/* 079 */       } else {
/* 080 */         final int project_numElements = project_arrayData.numElements();
/* 081 */         project_arrayWriter.initialize(project_holder, project_numElements, 8);
/* 082 */
/* 083 */         for (int project_index = 0; project_index < project_numElements; project_index++) {
/* 084 */           if (project_arrayData.isNullAt(project_index)) {
/* 085 */             project_arrayWriter.setNullDouble(project_index);
/* 086 */           } else {
/* 087 */             final double project_element = project_arrayData.getDouble(project_index);
/* 088 */             project_arrayWriter.write(project_index, project_element);
/* 089 */           }
/* 090 */         }
/* 091 */       }
/* 092 */
/* 093 */       project_rowWriter.setOffsetAndSize(0, project_tmpCursor, project_holder.cursor - project_tmpCursor);
/* 094 */       project_result.setTotalSize(project_holder.totalSize());
/* 095 */       append(project_result);
/* 096 */       if (shouldStop()) return;
/* 097 */     }
/* 098 */   }
/* 099 */ }
```

## How was this patch tested?

Added unit tests to `DataFrameComplexTypeSuite`.

Author: Kazuaki Ishizaki <[email protected]>
Author: Liang-Chi Hsieh <[email protected]>

Closes #13909 from kiszk/SPARK-16213.
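For readers unfamiliar with the optimized path, the following hand-written Scala sketch (an editor's illustration, not part of the patch) mirrors what the generated code at lines 042-045 above does for a two-element double array; it uses only APIs that appear in the diffs below.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods

val numElements = 2
// Header (element count + null bitset) plus 8 bytes per double, rounded to a
// word boundary -- the same size computation as GenArrayData.genCodeToCreateArrayData.
val sizeInBytes = (UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
  ByteArrayMethods.roundNumberOfBytesToNearestWord(8 * numElements)).toInt
val buffer = new Array[Byte](sizeInBytes)                          // 32 bytes here
Platform.putLong(buffer, Platform.BYTE_ARRAY_OFFSET, numElements)  // write element count
val arrayData = new UnsafeArrayData()
arrayData.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET, sizeInBytes)
arrayData.setDouble(0, 1.1)  // direct primitive write, no boxing
arrayData.setDouble(1, 2.2)
assert(arrayData.getDouble(1) == 2.2)
```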
1 parent: 092c672

File tree

9 files changed: +230 -91 lines


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 52 additions & 0 deletions
```diff
@@ -287,6 +287,58 @@ public UnsafeMapData getMap(int ordinal) {
     return map;
   }
 
+  @Override
+  public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
+
+  public void setNullAt(int ordinal) {
+    assertIndexIsValid(ordinal);
+    BitSetMethods.set(baseObject, baseOffset + 8, ordinal);
+
+    /* we assume the corresponding column was already 0 or
+       will be set to 0 later by the caller side */
+  }
+
+  public void setBoolean(int ordinal, boolean value) {
+    assertIndexIsValid(ordinal);
+    Platform.putBoolean(baseObject, getElementOffset(ordinal, 1), value);
+  }
+
+  public void setByte(int ordinal, byte value) {
+    assertIndexIsValid(ordinal);
+    Platform.putByte(baseObject, getElementOffset(ordinal, 1), value);
+  }
+
+  public void setShort(int ordinal, short value) {
+    assertIndexIsValid(ordinal);
+    Platform.putShort(baseObject, getElementOffset(ordinal, 2), value);
+  }
+
+  public void setInt(int ordinal, int value) {
+    assertIndexIsValid(ordinal);
+    Platform.putInt(baseObject, getElementOffset(ordinal, 4), value);
+  }
+
+  public void setLong(int ordinal, long value) {
+    assertIndexIsValid(ordinal);
+    Platform.putLong(baseObject, getElementOffset(ordinal, 8), value);
+  }
+
+  public void setFloat(int ordinal, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    }
+    assertIndexIsValid(ordinal);
+    Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value);
+  }
+
+  public void setDouble(int ordinal, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    }
+    assertIndexIsValid(ordinal);
+    Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value);
+  }
+
   // This `hashCode` computation could consume much processor time for large data.
   // If the computation becomes a bottleneck, we can use a light-weight logic; the first fixed bytes
   // are used to compute `hashCode` (See `Vector.hashCode`).
```
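The NaN checks in `setFloat` and `setDouble` canonicalize the stored bit pattern, presumably so that byte-wise comparison of two arrays treats all NaNs as equal. A hand-written illustration (not from the patch):

```scala
// A NaN with a non-canonical bit pattern vs. the canonical one.
val nan1 = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
val nan2 = java.lang.Double.NaN
assert(nan1.isNaN && nan2.isNaN)
// Their raw bits differ, so a byte-wise comparison would call them unequal...
assert(java.lang.Double.doubleToRawLongBits(nan1) !=
       java.lang.Double.doubleToRawLongBits(nan2))
// ...but after normalizing to Double.NaN (as setDouble does before writing),
// the stored bits are identical.
assert(java.lang.Double.doubleToLongBits(nan1) ==
       java.lang.Double.doubleToLongBits(nan2))
```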

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 109 additions & 65 deletions
```diff
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -43,7 +44,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
 
-  override def dataType: DataType = {
+  override def dataType: ArrayType = {
     ArrayType(
       children.headOption.map(_.dataType).getOrElse(NullType),
       containsNull = children.exists(_.nullable))
@@ -56,33 +57,99 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val arrayClass = classOf[GenericArrayData].getName
-    val values = ctx.freshName("values")
-    ctx.addMutableState("Object[]", values, s"this.$values = null;")
-
-    ev.copy(code = s"""
-      this.$values = new Object[${children.size}];""" +
-      ctx.splitExpressions(
-        ctx.INPUT_ROW,
-        children.zipWithIndex.map { case (e, i) =>
-          val eval = e.genCode(ctx)
-          eval.code + s"""
-            if (${eval.isNull}) {
-              $values[$i] = null;
-            } else {
-              $values[$i] = ${eval.value};
-            }
-           """
-        }) +
-      s"""
-        final ArrayData ${ev.value} = new $arrayClass($values);
-        this.$values = null;
-      """, isNull = "false")
+    val et = dataType.elementType
+    val evals = children.map(e => e.genCode(ctx))
+    val (preprocess, assigns, postprocess, arrayData) =
+      GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
+    ev.copy(
+      code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess,
+      value = arrayData,
+      isNull = "false")
   }
 
   override def prettyName: String = "array"
 }
 
+private [sql] object GenArrayData {
+  /**
+   * Return Java code pieces based on DataType and isPrimitive to allocate ArrayData class
+   *
+   * @param ctx a [[CodegenContext]]
+   * @param elementType data type of underlying array elements
+   * @param elementsCode a set of [[ExprCode]] for each element of an underlying array
+   * @param isMapKey if true, throw an exception when the element is null
+   * @return (code pre-assignments, assignments to each array elements, code post-assignments,
+   *         arrayData name)
+   */
+  def genCodeToCreateArrayData(
+      ctx: CodegenContext,
+      elementType: DataType,
+      elementsCode: Seq[ExprCode],
+      isMapKey: Boolean): (String, Seq[String], String, String) = {
+    val arrayName = ctx.freshName("array")
+    val arrayDataName = ctx.freshName("arrayData")
+    val numElements = elementsCode.length
+
+    if (!ctx.isPrimitiveType(elementType)) {
+      val genericArrayClass = classOf[GenericArrayData].getName
+      ctx.addMutableState("Object[]", arrayName,
+        s"this.$arrayName = new Object[${numElements}];")
+
+      val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
+        val isNullAssignment = if (!isMapKey) {
+          s"$arrayName[$i] = null;"
+        } else {
+          "throw new RuntimeException(\"Cannot use null as map key!\");"
+        }
+        eval.code + s"""
+          if (${eval.isNull}) {
+            $isNullAssignment
+          } else {
+            $arrayName[$i] = ${eval.value};
+          }
+        """
+      }
+
+      ("",
+        assignments,
+        s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
+        arrayDataName)
+    } else {
+      val unsafeArraySizeInBytes =
+        UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
+          ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
+      val baseOffset = Platform.BYTE_ARRAY_OFFSET
+      ctx.addMutableState("UnsafeArrayData", arrayDataName, "");
+
+      val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
+      val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
+        val isNullAssignment = if (!isMapKey) {
+          s"$arrayDataName.setNullAt($i);"
+        } else {
+          "throw new RuntimeException(\"Cannot use null as map key!\");"
+        }
+        eval.code + s"""
+          if (${eval.isNull}) {
+            $isNullAssignment
+          } else {
+            $arrayDataName.set$primitiveValueTypeName($i, ${eval.value});
+          }
+        """
+      }
+
+      (s"""
+        byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
+        $arrayDataName = new UnsafeArrayData();
+        Platform.putLong($arrayName, $baseOffset, $numElements);
+        $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
+      """,
+        assignments,
+        "",
+        arrayDataName)
+    }
+  }
+}
+
 /**
  * Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
  * The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...)
@@ -133,49 +200,26 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val arrayClass = classOf[GenericArrayData].getName
     val mapClass = classOf[ArrayBasedMapData].getName
-    val keyArray = ctx.freshName("keyArray")
-    val valueArray = ctx.freshName("valueArray")
-    ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;")
-    ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;")
-
-    val keyData = s"new $arrayClass($keyArray)"
-    val valueData = s"new $arrayClass($valueArray)"
-    ev.copy(code = s"""
-      $keyArray = new Object[${keys.size}];
-      $valueArray = new Object[${values.size}];""" +
-      ctx.splitExpressions(
-        ctx.INPUT_ROW,
-        keys.zipWithIndex.map { case (key, i) =>
-          val eval = key.genCode(ctx)
-          s"""
-            ${eval.code}
-            if (${eval.isNull}) {
-              throw new RuntimeException("Cannot use null as map key!");
-            } else {
-              $keyArray[$i] = ${eval.value};
-            }
-          """
-        }) +
-      ctx.splitExpressions(
-        ctx.INPUT_ROW,
-        values.zipWithIndex.map { case (value, i) =>
-          val eval = value.genCode(ctx)
-          s"""
-            ${eval.code}
-            if (${eval.isNull}) {
-              $valueArray[$i] = null;
-            } else {
-              $valueArray[$i] = ${eval.value};
-            }
-          """
-        }) +
+    val MapType(keyDt, valueDt, _) = dataType
+    val evalKeys = keys.map(e => e.genCode(ctx))
+    val evalValues = values.map(e => e.genCode(ctx))
+    val (preprocessKeyData, assignKeys, postprocessKeyData, keyArrayData) =
+      GenArrayData.genCodeToCreateArrayData(ctx, keyDt, evalKeys, true)
+    val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) =
+      GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false)
+    val code =
       s"""
-        final MapData ${ev.value} = new $mapClass($keyData, $valueData);
-        this.$keyArray = null;
-        this.$valueArray = null;
-      """, isNull = "false")
+      final boolean ${ev.isNull} = false;
+      $preprocessKeyData
+      ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)}
+      $postprocessKeyData
+      $preprocessValueData
+      ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)}
+      $postprocessValueData
+      final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
+      """
+    ev.copy(code = code)
   }
 
   override def prettyName: String = "map"
```
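To make the new four-part contract concrete, here is a sketch of the caller-side protocol, distilled from the `CreateArray` change above (a hand-written illustration; it assumes `ctx: CodegenContext` and `evals: Seq[ExprCode]` are in scope, as in `doGenCode`):

```scala
// preprocess : runs once; on the primitive path it allocates the byte[] and
//              points an UnsafeArrayData at it, otherwise it is empty
// assigns    : one Java snippet per element, kept as a Seq so that
//              ctx.splitExpressions can spread them across generated methods
// postprocess: finishes construction (e.g. wraps the Object[] in a
//              GenericArrayData); empty on the primitive path
// arrayData  : the generated-code variable name holding the result
val (preprocess, assigns, postprocess, arrayData) =
  GenArrayData.genCodeToCreateArrayData(ctx, DoubleType, evals, false)
val code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess
```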

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -42,6 +42,19 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
 
   def array: Array[Any]
 
+  def setNullAt(i: Int): Unit
+
+  def update(i: Int, value: Any): Unit
+
+  // default implementation (slow)
+  def setBoolean(i: Int, value: Boolean): Unit = update(i, value)
+  def setByte(i: Int, value: Byte): Unit = update(i, value)
+  def setShort(i: Int, value: Short): Unit = update(i, value)
+  def setInt(i: Int, value: Int): Unit = update(i, value)
+  def setLong(i: Int, value: Long): Unit = update(i, value)
+  def setFloat(i: Int, value: Float): Unit = update(i, value)
+  def setDouble(i: Int, value: Double): Unit = update(i, value)
+
   def toBooleanArray(): Array[Boolean] = {
     val size = numElements()
     val values = new Array[Boolean](size)
```
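The typed setters on the base class all route through `update`, so a write into a `GenericArrayData` (next file) still boxes the value into the backing `Object[]`; `UnsafeArrayData` overrides them with the direct `Platform.put*` writes shown earlier. A small hand-written example of the default path:

```scala
import org.apache.spark.sql.catalyst.util.GenericArrayData

val arr = new GenericArrayData(new Array[Any](2))
arr.setInt(0, 42)  // default setter: update(0, 42) boxes into java.lang.Integer
arr.setNullAt(1)   // stores null in the backing Object[]
assert(arr.getInt(0) == 42)
assert(arr.isNullAt(1))
```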

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala

Lines changed: 4 additions & 0 deletions
```diff
@@ -71,6 +71,10 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
   override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
   override def getMap(ordinal: Int): MapData = getAs(ordinal)
 
+  override def setNullAt(ordinal: Int): Unit = array(ordinal) = null
+
+  override def update(ordinal: Int, value: Any): Unit = array(ordinal) = value
+
   override def toString(): String = array.mkString("[", ",", "]")
 
   override def equals(o: Any): Boolean = {
```
