
Commit be508a5

icexelloss authored and BryanCutler committed
Fix conversion for String type; refactor related functions to Arrow.scala
Changed tests to use the existing SQLTestData and removed unused files. Closes apache#14.
1 parent a4b958e commit be508a5

File tree

5 files changed: +282 −205 lines

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import scala.collection.JavaConverters._
import scala.language.implicitConversions

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.BitVector
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
object Arrow {

  /**
   * Compute the number of bytes needed to build a validity map. Per the
   * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps),
   * the validity bitmap is a sequence of 64-bit words, so its length is
   * rounded up to a whole number of words, i.e. a multiple of 8 bytes.
   */
  private def numBytesOfBitmap(numOfRows: Int): Int = {
    Math.ceil(numOfRows / 64.0).toInt * 8
  }
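  // For example, 100 rows need ceil(100 / 64.0) = 2 words, i.e. 2 * 8 = 16 bytes.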
  /**
   * Write a placeholder (zero) value of the given type to the ArrowBuf,
   * so that a null entry still occupies its slot in a fixed-width buffer.
   */
  private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = {
    dataType match {
      case NullType =>
      case BooleanType =>
        buf.writeBoolean(false)
      case ShortType =>
        buf.writeShort(0)
      case IntegerType =>
        buf.writeInt(0)
      case LongType =>
        buf.writeLong(0L)
      case FloatType =>
        buf.writeFloat(0f)
      case DoubleType =>
        buf.writeDouble(0d)
      case ByteType =>
        buf.writeByte(0)
      case _ =>
        throw new UnsupportedOperationException(
          s"Unsupported data type ${dataType.simpleString}")
    }
  }
  /**
   * Read an entry from the InternalRow and write it to the ArrowBuf.
   * Note: no null check is performed on the entry.
   */
  private def getAndSetToArrow(
      row: InternalRow,
      buf: ArrowBuf,
      dataType: DataType,
      ordinal: Int): Unit = {
    dataType match {
      case NullType =>
      case BooleanType =>
        buf.writeBoolean(row.getBoolean(ordinal))
      case ShortType =>
        buf.writeShort(row.getShort(ordinal))
      case IntegerType =>
        buf.writeInt(row.getInt(ordinal))
      case LongType =>
        buf.writeLong(row.getLong(ordinal))
      case FloatType =>
        buf.writeFloat(row.getFloat(ordinal))
      case DoubleType =>
        buf.writeDouble(row.getDouble(ordinal))
      case ByteType =>
        buf.writeByte(row.getByte(ordinal))
      case _ =>
        throw new UnsupportedOperationException(
          s"Unsupported data type ${dataType.simpleString}")
    }
  }
  /**
   * Convert an array of InternalRow to an ArrowRecordBatch.
   */
  def internalRowsToArrowRecordBatch(
      rows: Array[InternalRow],
      schema: StructType,
      allocator: RootAllocator): ArrowRecordBatch = {
    val bufAndField = schema.fields.zipWithIndex.map { case (field, ordinal) =>
      internalRowToArrowBuf(rows, ordinal, field, allocator)
    }

    val buffers = bufAndField.flatMap(_._1).toList.asJava
    val fieldNodes = bufAndField.flatMap(_._2).toList.asJava

    new ArrowRecordBatch(rows.length, fieldNodes, buffers)
  }
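  // The flattened buffer list preserves schema order: for each column the
  // validity buffer comes first, followed by that column's data buffer(s),
  // with one ArrowFieldNode (row count + null count) per column.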
  /**
   * Convert one column (the field at the given ordinal) of an array of
   * InternalRow to Arrow buffers, together with its ArrowFieldNode.
   */
  def internalRowToArrowBuf(
      rows: Array[InternalRow],
      ordinal: Int,
      field: StructField,
      allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
    val numOfRows = rows.length

    field.dataType match {
      case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
        val validityVector = new BitVector("validity", allocator)
        val validityMutator = validityVector.getMutator
        validityVector.allocateNew(numOfRows)
        validityMutator.setValueCount(numOfRows)

        val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
        var nullCount = 0
        var index = 0
        while (index < rows.length) {
          val row = rows(index)
          if (row.isNullAt(ordinal)) {
            nullCount += 1
            validityMutator.set(index, 0)
            fillArrow(buf, field.dataType)
          } else {
            validityMutator.set(index, 1)
            getAndSetToArrow(row, buf, field.dataType, ordinal)
          }
          index += 1
        }

        val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

        (Array(validityVector.getBuffer, buf), Array(fieldNode))

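      // A Utf8 column is encoded with three buffers: the validity bitmap, an
      // int32 offset buffer with numOfRows + 1 entries, and a values buffer
      // holding the concatenated UTF-8 bytes.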
      case StringType =>
        val validityVector = new BitVector("validity", allocator)
        val validityMutator = validityVector.getMutator
        validityVector.allocateNew(numOfRows)
        validityMutator.setValueCount(numOfRows)

        val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
        var bytesCount = 0
        bufOffset.writeInt(bytesCount)
        val bufValues = allocator.buffer(1024)
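        // NOTE: the values buffer is allocated with a fixed 1024-byte capacity,
        // which assumes the column's total UTF-8 payload fits in 1 KB; larger
        // inputs would need this buffer to be sized or grown accordingly.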
        var nullCount = 0
        rows.zipWithIndex.foreach { case (row, index) =>
          if (row.isNullAt(ordinal)) {
            nullCount += 1
            validityMutator.set(index, 0)
            bufOffset.writeInt(bytesCount)
          } else {
            validityMutator.set(index, 1)
            val bytes = row.getUTF8String(ordinal).getBytes
            bytesCount += bytes.length
            bufOffset.writeInt(bytesCount)
            bufValues.writeBytes(bytes)
          }
        }

        val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

        (Array(validityVector.getBuffer, bufOffset, bufValues),
          Array(fieldNode))
    }
  }
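  // Note: the match in internalRowToArrowBuf has no default case, so an
  // unsupported column type fails with a scala.MatchError rather than the
  // UnsupportedOperationException thrown by the per-value helpers above.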
  private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
    val arrowFields = schema.fields.map(sparkFieldToArrowField(_))
    new Schema(arrowFields.toList.asJava)
  }
  private[sql] def sparkFieldToArrowField(sparkField: StructField): Field = {
    val name = sparkField.name
    val dataType = sparkField.dataType
    val nullable = sparkField.nullable

    dataType match {
      case StructType(fields) =>
        val childrenFields = fields.map(sparkFieldToArrowField(_)).toList.asJava
        new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields)
      case _ =>
        new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
    }
  }
  /**
   * Map a Spark DataType to the corresponding Arrow ArrowType.
   */
  private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = {
    dt match {
      case IntegerType =>
        new ArrowType.Int(8 * IntegerType.defaultSize, true)
      case LongType =>
        new ArrowType.Int(8 * LongType.defaultSize, true)
      case StringType =>
        ArrowType.Utf8.INSTANCE
      case DoubleType =>
        new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
      case FloatType =>
        new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
      case BooleanType =>
        ArrowType.Bool.INSTANCE
      case ByteType =>
        // Spark's ByteType is a signed 8-bit integer, so the Arrow type
        // must be signed as well.
        new ArrowType.Int(8, true)
      case StructType(_) =>
        ArrowType.Struct.INSTANCE
      case _ =>
        throw new IllegalArgumentException(s"Unsupported data type ${dt.simpleString}")
    }
  }
}
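For reference, here is a minimal sketch of how these helpers fit together. It is illustrative only: the two-column schema, the sample rows, and the allocator limit are invented for the example and are not part of this commit.

// Illustrative usage of the helpers above (not part of the commit).
import org.apache.arrow.memory.RootAllocator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = false),
  StructField("name", StringType, nullable = true)))

// InternalRow stores strings as UTF8String; the second row has a null name.
val rows: Array[InternalRow] = Array(
  InternalRow(1, UTF8String.fromString("a")),
  InternalRow(2, null))

val allocator = new RootAllocator(Long.MaxValue)
val arrowSchema = Arrow.schemaToArrowSchema(schema)
val batch = Arrow.internalRowsToArrowRecordBatch(rows, schema, allocator)
// batch now holds, per column: a validity buffer plus the data buffer(s)
// (for "name": offsets + values), described by one ArrowFieldNode each.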
