Skip to content

Commit 9a63852

Browse files
committed
improve the unsafe row writing framework
1 parent bc36b0f commit 9a63852

7 files changed

Lines changed: 235 additions & 77 deletions

File tree

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,36 @@
2121
import org.apache.spark.unsafe.Platform;
2222

2323
/**
24-
* A helper class to manage the row buffer when construct unsafe rows.
24+
* A helper class to manage the data buffer for an unsafe row. The data buffer can grow and
25+
* automatically re-point the unsafe row to it.
26+
*
27+
* This class can be used to build a one-pass unsafe row writing program, i.e. data will be written
28+
* to the data buffer directly and no extra copy is needed. There should be only one instance of
29+
* this class per writing program, so that the memory segment/data buffer can be reused. Note that
30+
* for each incoming record, we should call `reset` of BufferHolder instance before write the record
31+
* and reuse the data buffer.
2532
*/
2633
public class BufferHolder {
2734
public byte[] buffer;
2835
public int cursor = Platform.BYTE_ARRAY_OFFSET;
36+
private final UnsafeRow row;
37+
private final int fixedSize;
2938

30-
public BufferHolder() {
31-
this(64);
39+
public BufferHolder(UnsafeRow row) {
40+
this(row, 64);
3241
}
3342

34-
public BufferHolder(int size) {
35-
buffer = new byte[size];
43+
public BufferHolder(UnsafeRow row, int initialSize) {
44+
this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields();
45+
this.buffer = new byte[fixedSize + initialSize];
46+
this.row = row;
47+
this.row.pointTo(buffer, buffer.length);
3648
}
3749

3850
/**
39-
* Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer.
51+
* Grows the buffer by at least neededSize and points the row to the buffer.
4052
*/
41-
public void grow(int neededSize, UnsafeRow row) {
53+
public void grow(int neededSize) {
4254
final int length = totalSize() + neededSize;
4355
if (buffer.length < length) {
4456
// This will not happen frequently, because the buffer is re-used.
@@ -50,22 +62,12 @@ public void grow(int neededSize, UnsafeRow row) {
5062
Platform.BYTE_ARRAY_OFFSET,
5163
totalSize());
5264
buffer = tmp;
53-
if (row != null) {
54-
row.pointTo(buffer, length * 2);
55-
}
65+
row.pointTo(buffer, buffer.length);
5666
}
5767
}
5868

59-
public void grow(int neededSize) {
60-
grow(neededSize, null);
61-
}
62-
6369
public void reset() {
64-
cursor = Platform.BYTE_ARRAY_OFFSET;
65-
}
66-
public void resetTo(int offset) {
67-
assert(offset <= buffer.length);
68-
cursor = Platform.BYTE_ARRAY_OFFSET + offset;
70+
cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
6971
}
7072

7173
public int totalSize() {

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,44 @@
2626
import org.apache.spark.unsafe.types.UTF8String;
2727

2828
/**
29-
* A helper class to write data into global row buffer using `UnsafeRow` format,
30-
* used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
29+
* A helper class to write data into global row buffer using `UnsafeRow` format.
30+
*
31+
* It will remember the offset of row buffer which it starts to write, and move the cursor of row
32+
* buffer while writing. If a new record comes, the cursor of row buffer will be reset, so we need
33+
* to also call `reset` of this class before writing, to update the `startingOffset` and clear out
34+
* null bits. Note that if we use it to write data into the result unsafe row, which means we will
35+
* always write from the very beginning of the global row buffer, we don't need to update
36+
* `startingOffset` and can just call `zeroOutNullBites` before writing new record.
3137
*/
3238
public class UnsafeRowWriter {
3339

34-
private BufferHolder holder;
40+
private final BufferHolder holder;
3541
// The offset of the global buffer where we start to write this row.
3642
private int startingOffset;
37-
private int nullBitsSize;
38-
private UnsafeRow row;
43+
private final int nullBitsSize;
44+
private final int fixedSize;
3945

40-
public void initialize(BufferHolder holder, int numFields) {
41-
this.holder = holder;
46+
public void reset() {
4247
this.startingOffset = holder.cursor;
43-
this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
4448

4549
// grow the global buffer to make sure it has enough space to write fixed-length data.
46-
final int fixedSize = nullBitsSize + 8 * numFields;
47-
holder.grow(fixedSize, row);
50+
holder.grow(fixedSize);
4851
holder.cursor += fixedSize;
4952

50-
// zero-out the null bits region
53+
zeroOutNullBites();
54+
}
55+
56+
public void zeroOutNullBites() {
5157
for (int i = 0; i < nullBitsSize; i += 8) {
5258
Platform.putLong(holder.buffer, startingOffset + i, 0L);
5359
}
5460
}
5561

56-
public void initialize(UnsafeRow row, BufferHolder holder, int numFields) {
57-
initialize(holder, numFields);
58-
this.row = row;
62+
public UnsafeRowWriter(BufferHolder holder, int numFields) {
63+
this.holder = holder;
64+
this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
65+
this.fixedSize = nullBitsSize + 8 * numFields;
66+
this.startingOffset = holder.cursor;
5967
}
6068

6169
private void zeroOutPaddingBytes(int numBytes) {
@@ -98,7 +106,7 @@ public void alignToWords(int numBytes) {
98106

99107
if (remainder > 0) {
100108
final int paddingBytes = 8 - remainder;
101-
holder.grow(paddingBytes, row);
109+
holder.grow(paddingBytes);
102110

103111
for (int i = 0; i < paddingBytes; i++) {
104112
Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
@@ -161,7 +169,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
161169
}
162170
} else {
163171
// grow the global buffer before writing data.
164-
holder.grow(16, row);
172+
holder.grow(16);
165173

166174
// zero-out the bytes
167175
Platform.putLong(holder.buffer, holder.cursor, 0L);
@@ -193,7 +201,7 @@ public void write(int ordinal, UTF8String input) {
193201
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
194202

195203
// grow the global buffer before writing data.
196-
holder.grow(roundedSize, row);
204+
holder.grow(roundedSize);
197205

198206
zeroOutPaddingBytes(numBytes);
199207

@@ -214,7 +222,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) {
214222
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
215223

216224
// grow the global buffer before writing data.
217-
holder.grow(roundedSize, row);
225+
holder.grow(roundedSize);
218226

219227
zeroOutPaddingBytes(numBytes);
220228

@@ -230,7 +238,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) {
230238

231239
public void write(int ordinal, CalendarInterval input) {
232240
// grow the global buffer before writing data.
233-
holder.grow(16, row);
241+
holder.grow(16);
234242

235243
// Write the months and microseconds fields of Interval to the variable length portion.
236244
Platform.putLong(holder.buffer, holder.cursor, input.months);

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
4343
case _ => false
4444
}
4545

46-
private val rowWriterClass = classOf[UnsafeRowWriter].getName
47-
private val arrayWriterClass = classOf[UnsafeArrayWriter].getName
48-
4946
// TODO: if the nullability of field is correct, we can use it to save null check.
5047
private def writeStructToBuffer(
5148
ctx: CodegenContext,
@@ -73,9 +70,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
7370
row: String,
7471
inputs: Seq[ExprCode],
7572
inputTypes: Seq[DataType],
76-
bufferHolder: String): String = {
73+
bufferHolder: String,
74+
isTopLevel: Boolean = false): String = {
75+
val rowWriterClass = classOf[UnsafeRowWriter].getName
7776
val rowWriter = ctx.freshName("rowWriter")
78-
ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
77+
ctx.addMutableState(rowWriterClass, rowWriter,
78+
s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
79+
80+
val resetWriter = if (isTopLevel) {
81+
// For top level row writer, it always writes to the beginning of the global buffer holder,
82+
// which means its fixed-size region always in the same position, so we don't need to call
83+
// `reset` to set up its fixed-size region every time.
84+
if (inputs.map(_.isNull).forall(_ == "false")) {
85+
// If all fields are not nullable, which means the null bits never changes, then we don't
86+
// need to clear it out every time.
87+
""
88+
} else {
89+
s"$rowWriter.zeroOutNullBites();"
90+
}
91+
} else {
92+
s"$rowWriter.reset();"
93+
}
7994

8095
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
8196
case ((input, dataType), index) =>
@@ -122,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
122137
$rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
123138
"""
124139

125-
case _ if ctx.isPrimitiveType(dt) =>
126-
s"""
127-
$rowWriter.write($index, ${input.value});
128-
"""
129-
130140
case t: DecimalType =>
131141
s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});"
132142

@@ -153,7 +163,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
153163
}
154164

155165
s"""
156-
$rowWriter.initialize($bufferHolder, ${inputs.length});
166+
$resetWriter
157167
${ctx.splitExpressions(row, writeFields)}
158168
""".trim
159169
}
@@ -164,6 +174,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
164174
input: String,
165175
elementType: DataType,
166176
bufferHolder: String): String = {
177+
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
167178
val arrayWriter = ctx.freshName("arrayWriter")
168179
ctx.addMutableState(arrayWriterClass, arrayWriter,
169180
s"this.$arrayWriter = new $arrayWriterClass();")
@@ -288,22 +299,44 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
288299
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
289300
val exprTypes = expressions.map(_.dataType)
290301

302+
val numVarLenFields = exprTypes.count {
303+
case dt if ctx.isPrimitiveType(dt) => false
304+
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => false
305+
// TODO: consider large decimal and interval type
306+
case _ => true
307+
}
308+
291309
val result = ctx.freshName("result")
292310
ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
293-
val bufferHolder = ctx.freshName("bufferHolder")
311+
312+
val holder = ctx.freshName("holder")
294313
val holderClass = classOf[BufferHolder].getName
295-
ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
314+
ctx.addMutableState(holderClass, holder,
315+
s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")
316+
317+
val resetBufferHolder = if (numVarLenFields == 0) {
318+
""
319+
} else {
320+
s"$holder.reset();"
321+
}
322+
val updateRowSize = if (numVarLenFields == 0) {
323+
""
324+
} else {
325+
s"$result.setTotalSize($holder.totalSize());"
326+
}
296327

297328
// Reset the subexpression values for each row.
298329
val subexprReset = ctx.subExprResetVariables.mkString("\n")
299330

331+
val writeExpressions =
332+
writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)
333+
300334
val code =
301335
s"""
302-
$bufferHolder.reset();
336+
$resetBufferHolder
303337
$subexprReset
304-
${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
305-
306-
$result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
338+
$writeExpressions
339+
$updateRowSize
307340
"""
308341
ExprCode(code, "false", result)
309342
}

0 commit comments

Comments
 (0)