Skip to content

Commit 6676e74

Browse files
author
Davies Liu
committed
defer dictionary decoding
1 parent c7fccb5 commit 6676e74

File tree

10 files changed

+195
-125
lines changed

10 files changed

+195
-125
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ object Decimal {
340340
val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
341341
val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
342342

343+
/** Maximum number of decimal digits a Int can represent */
344+
val MAX_INT_DIGITS = 9
345+
343346
/** Maximum number of decimal digits a Long can represent */
344347
val MAX_LONG_DIGITS = 18
345348

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType {
150150
}
151151
}
152152

153+
/**
154+
* Returns if dt is a DecimalType that fits inside a int
155+
*/
156+
def is32BitDecimalType(dt: DataType): Boolean = {
157+
dt match {
158+
case t: DecimalType =>
159+
t.precision <= Decimal.MAX_LONG_DIGITS
160+
case _ => false
161+
}
162+
}
163+
153164
/**
154165
* Returns if dt is a DecimalType that fits inside a long
155166
*/

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,7 @@ private void initializeInternal() throws IOException {
257257
throw new IOException("Unsupported type: " + t);
258258
}
259259
if (originalTypes[i] == OriginalType.DECIMAL &&
260-
primitiveType.getDecimalMetadata().getPrecision() >
261-
CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) {
260+
primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) {
262261
throw new IOException("Decimal with high precision is not supported.");
263262
}
264263
if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) {
@@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOExcept
439438
PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType();
440439
int precision = type.getDecimalMetadata().getPrecision();
441440
int scale = type.getDecimalMetadata().getScale();
442-
Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(),
441+
Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(),
443442
"Unsupported precision.");
444443

445444
for (int n = 0; n < num; ++n) {
@@ -611,6 +610,11 @@ private boolean next() throws IOException {
611610
*/
612611
private void readBatch(int total, ColumnVector column) throws IOException {
613612
int rowId = 0;
613+
if (useDictionary) {
614+
dictionaryIds = column.reserveDictionaryIds(total);
615+
} else {
616+
column.setDictionary(null);
617+
}
614618
while (total > 0) {
615619
// Compute the number of values we want to read in this page.
616620
int leftInPage = (int)(endOfPageValueCount - valuesRead);
@@ -620,13 +624,6 @@ private void readBatch(int total, ColumnVector column) throws IOException {
620624
}
621625
int num = Math.min(total, leftInPage);
622626
if (useDictionary) {
623-
// Data is dictionary encoded. We will vector decode the ids and then resolve the values.
624-
if (dictionaryIds == null) {
625-
dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
626-
} else {
627-
dictionaryIds.reset();
628-
dictionaryIds.reserve(total);
629-
}
630627
// Read and decode dictionary ids.
631628
defColumn.readIntegers(
632629
num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
@@ -672,21 +669,13 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
672669
switch (descriptor.getType()) {
673670
case INT32:
674671
if (column.dataType() == DataTypes.IntegerType) {
675-
for (int i = rowId; i < rowId + num; ++i) {
676-
column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
677-
}
672+
column.setDictionary(dictionary);
678673
} else if (column.dataType() == DataTypes.ByteType) {
679-
for (int i = rowId; i < rowId + num; ++i) {
680-
column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i)));
681-
}
674+
column.setDictionary(dictionary);
682675
} else if (column.dataType() == DataTypes.ShortType) {
683-
for (int i = rowId; i < rowId + num; ++i) {
684-
column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i)));
685-
}
676+
column.setDictionary(dictionary);
686677
} else if (DecimalType.is64BitDecimalType(column.dataType())) {
687-
for (int i = rowId; i < rowId + num; ++i) {
688-
column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
689-
}
678+
column.setDictionary(dictionary);
690679
} else {
691680
throw new NotImplementedException("Unimplemented type: " + column.dataType());
692681
}
@@ -695,28 +684,28 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
695684
case INT64:
696685
if (column.dataType() == DataTypes.LongType ||
697686
DecimalType.is64BitDecimalType(column.dataType())) {
698-
for (int i = rowId; i < rowId + num; ++i) {
699-
column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
700-
}
687+
column.setDictionary(dictionary);
701688
} else {
702689
throw new NotImplementedException("Unimplemented type: " + column.dataType());
703690
}
704691
break;
705692

706693
case FLOAT:
707-
for (int i = rowId; i < rowId + num; ++i) {
708-
column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i)));
709-
}
694+
column.setDictionary(dictionary);
710695
break;
711696

712697
case DOUBLE:
713-
for (int i = rowId; i < rowId + num; ++i) {
714-
column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i)));
715-
}
698+
column.setDictionary(dictionary);
716699
break;
717700

718701
case FIXED_LEN_BYTE_ARRAY:
719-
if (DecimalType.is64BitDecimalType(column.dataType())) {
702+
// This is the legacy mode to write DecimalType
703+
if (DecimalType.is32BitDecimalType(column.dataType())) {
704+
for (int i = rowId; i < rowId + num; ++i) {
705+
Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
706+
column.putInt(i,(int) CatalystRowConverter.binaryToUnscaledLong(v));
707+
}
708+
} else if (DecimalType.is64BitDecimalType(column.dataType())) {
720709
for (int i = rowId; i < rowId + num; ++i) {
721710
Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
722711
column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
@@ -727,14 +716,7 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
727716
break;
728717

729718
case BINARY:
730-
// TODO: this is incredibly inefficient as it blows up the dictionary right here. We
731-
// need to do this better. We should probably add the dictionary data to the ColumnVector
732-
// and reuse it across batches. This should mean adding a ByteArray would just update
733-
// the length and offset.
734-
for (int i = rowId; i < rowId + num; ++i) {
735-
Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
736-
column.putByteArray(i, v.getBytes());
737-
}
719+
column.setDictionary(dictionary);
738720
break;
739721

740722
default:

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
import java.math.BigDecimal;
2020
import java.math.BigInteger;
2121

22+
import org.apache.commons.lang.NotImplementedException;
23+
import org.apache.parquet.column.Dictionary;
24+
import org.apache.parquet.io.api.Binary;
25+
2226
import org.apache.spark.memory.MemoryMode;
2327
import org.apache.spark.sql.catalyst.InternalRow;
2428
import org.apache.spark.sql.catalyst.util.ArrayData;
@@ -27,8 +31,6 @@
2731
import org.apache.spark.unsafe.types.CalendarInterval;
2832
import org.apache.spark.unsafe.types.UTF8String;
2933

30-
import org.apache.commons.lang.NotImplementedException;
31-
3234
/**
3335
* This class represents a column of values and provides the main APIs to access the data
3436
* values. It supports all the types and contains get/put APIs as well as their batched versions.
@@ -204,28 +206,17 @@ public float getFloat(int ordinal) {
204206

205207
@Override
206208
public Decimal getDecimal(int ordinal, int precision, int scale) {
207-
if (precision <= Decimal.MAX_LONG_DIGITS()) {
208-
return Decimal.apply(getLong(ordinal), precision, scale);
209-
} else {
210-
byte[] bytes = getBinary(ordinal);
211-
BigInteger bigInteger = new BigInteger(bytes);
212-
BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
213-
return Decimal.apply(javaDecimal, precision, scale);
214-
}
209+
return getDecimal(offset + ordinal, precision, scale);
215210
}
216211

217212
@Override
218213
public UTF8String getUTF8String(int ordinal) {
219-
Array child = data.getByteArray(offset + ordinal);
220-
return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length);
214+
return getUTF8String(offset + ordinal);
221215
}
222216

223217
@Override
224218
public byte[] getBinary(int ordinal) {
225-
ColumnVector.Array array = data.getByteArray(offset + ordinal);
226-
byte[] bytes = new byte[array.length];
227-
System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
228-
return bytes;
219+
return getBinary(offset + ordinal);
229220
}
230221

231222
@Override
@@ -540,6 +531,42 @@ public final Array getByteArray(int rowId) {
540531
return array;
541532
}
542533

534+
public final Decimal getDecimal(int rowId, int precision, int scale) {
535+
if (precision <= Decimal.MAX_INT_DIGITS()) {
536+
return Decimal.apply(getInt(rowId), precision, scale);
537+
} else if (precision <= Decimal.MAX_LONG_DIGITS()) {
538+
return Decimal.apply(getLong(rowId), precision, scale);
539+
} else {
540+
// TODO: best perf?
541+
byte[] bytes = getBinary(rowId);
542+
BigInteger bigInteger = new BigInteger(bytes);
543+
BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
544+
return Decimal.apply(javaDecimal, precision, scale);
545+
}
546+
}
547+
548+
public final UTF8String getUTF8String(int rowId) {
549+
if (dictionary == null) {
550+
ColumnVector.Array a = getByteArray(rowId);
551+
return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
552+
} else {
553+
Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId));
554+
return UTF8String.fromBytes(v.getBytes());
555+
}
556+
}
557+
558+
public final byte[] getBinary(int rowId) {
559+
if (dictionary == null) {
560+
ColumnVector.Array array = getByteArray(rowId);
561+
byte[] bytes = new byte[array.length];
562+
System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
563+
return bytes;
564+
} else {
565+
Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId));
566+
return v.getBytes();
567+
}
568+
}
569+
543570
/**
544571
* Append APIs. These APIs all behave similarly and will append data to the current vector. It
545572
* is not valid to mix the put and append APIs. The append APIs are slower and should only be
@@ -816,6 +843,39 @@ public final int appendStruct(boolean isNull) {
816843
*/
817844
protected final ColumnarBatch.Row resultStruct;
818845

846+
/**
847+
* The Dictionary for this column.
848+
*
849+
* If it's not null, will be used to decode the value in getXXX().
850+
*/
851+
protected Dictionary dictionary;
852+
853+
/**
854+
* Reusable column for ids of dictionary.
855+
*/
856+
protected ColumnVector dictionaryIds;
857+
858+
/**
859+
* Update the dictionary.
860+
*/
861+
public void setDictionary(Dictionary dictionary) {
862+
this.dictionary = dictionary;
863+
}
864+
865+
/**
866+
* Reserve a integer column for ids of dictionary.
867+
*/
868+
public ColumnVector reserveDictionaryIds(int capacity) {
869+
if (dictionaryIds == null) {
870+
dictionaryIds = allocate(capacity, DataTypes.IntegerType,
871+
this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP);
872+
} else {
873+
dictionaryIds.reset();
874+
dictionaryIds.reserve(capacity);
875+
}
876+
return dictionaryIds;
877+
}
878+
819879
/**
820880
* Sets up the common state and also handles creating the child columns if this is a nested
821881
* type.

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
*/
1717
package org.apache.spark.sql.execution.vectorized;
1818

19-
import java.math.BigDecimal;
20-
import java.math.BigInteger;
2119
import java.util.Arrays;
2220
import java.util.Iterator;
2321

22+
import org.apache.commons.lang.NotImplementedException;
23+
2424
import org.apache.spark.memory.MemoryMode;
2525
import org.apache.spark.sql.catalyst.InternalRow;
2626
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
@@ -31,8 +31,6 @@
3131
import org.apache.spark.unsafe.types.CalendarInterval;
3232
import org.apache.spark.unsafe.types.UTF8String;
3333

34-
import org.apache.commons.lang.NotImplementedException;
35-
3634
/**
3735
* This class is the in memory representation of rows as they are streamed through operators. It
3836
* is designed to maximize CPU efficiency and not storage footprint. Since it is expected that
@@ -193,29 +191,17 @@ public final boolean anyNull() {
193191

194192
@Override
195193
public final Decimal getDecimal(int ordinal, int precision, int scale) {
196-
if (precision <= Decimal.MAX_LONG_DIGITS()) {
197-
return Decimal.apply(getLong(ordinal), precision, scale);
198-
} else {
199-
// TODO: best perf?
200-
byte[] bytes = getBinary(ordinal);
201-
BigInteger bigInteger = new BigInteger(bytes);
202-
BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
203-
return Decimal.apply(javaDecimal, precision, scale);
204-
}
194+
return columns[ordinal].getDecimal(rowId, precision, scale);
205195
}
206196

207197
@Override
208198
public final UTF8String getUTF8String(int ordinal) {
209-
ColumnVector.Array a = columns[ordinal].getByteArray(rowId);
210-
return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
199+
return columns[ordinal].getUTF8String(rowId);
211200
}
212201

213202
@Override
214203
public final byte[] getBinary(int ordinal) {
215-
ColumnVector.Array array = columns[ordinal].getByteArray(rowId);
216-
byte[] bytes = new byte[array.length];
217-
System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
218-
return bytes;
204+
return columns[ordinal].getBinary(rowId);
219205
}
220206

221207
@Override

0 commit comments

Comments
 (0)