
Commit 4db8832
Author: Andrei Ionescu
Parent: eebe06a

Remove Unsafe projection from data reader

6 files changed
Lines changed: 57 additions & 51 deletions

spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersFlatDataBenchmark.java

Lines changed: 5 additions & 4 deletions
@@ -22,6 +22,7 @@
 import com.google.common.collect.Iterables;
 import java.io.File;
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import org.apache.avro.generic.GenericData;
 import org.apache.iceberg.Files;
@@ -116,7 +117,7 @@ public void tearDownBenchmark() {
   public void readUsingIcebergReader(Blackhole blackHole) throws IOException {
     try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
         .project(SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type, Collections.emptyMap()))
         .build()) {

       for (InternalRow row : rows) {
@@ -130,7 +131,7 @@ public void readUsingIcebergReader(Blackhole blackHole) throws IOException {
   public void readUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException {
     try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
         .project(SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type, Collections.emptyMap()))
         .build()) {

       Iterable<InternalRow> unsafeRows = Iterables.transform(
@@ -167,7 +168,7 @@ public void readUsingSparkReader(Blackhole blackhole) throws IOException {
   public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException {
     try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
         .project(PROJECTED_SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type, Collections.emptyMap()))
         .build()) {

       for (InternalRow row : rows) {
@@ -181,7 +182,7 @@ public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException {
   public void readWithProjectionUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException {
     try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
         .project(PROJECTED_SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type, Collections.emptyMap()))
         .build()) {

       Iterable<InternalRow> unsafeRows = Iterables.transform(
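
For reference, a minimal sketch of the updated call pattern these benchmark changes follow: buildReader now takes a third argument mapping field IDs to constant partition values, and callers with no identity-partition columns pass an empty map. The class name and method below are placeholders, not part of the commit.

    import java.io.File;
    import java.io.IOException;
    import java.util.Collections;
    import org.apache.iceberg.Files;
    import org.apache.iceberg.Schema;
    import org.apache.iceberg.io.CloseableIterable;
    import org.apache.iceberg.parquet.Parquet;
    import org.apache.iceberg.spark.data.SparkParquetReaders;
    import org.apache.spark.sql.catalyst.InternalRow;

    class ReadSketch {
      static void readAll(File dataFile, Schema schema) throws IOException {
        try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
            .project(schema)
            // new third argument: field ID -> constant value; empty here because
            // every projected column is present in the data file itself
            .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type, Collections.emptyMap()))
            .build()) {
          for (InternalRow row : rows) {
            // consume rows
          }
        }
      }
    }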

spark/src/jmh/java/org/apache/iceberg/spark/data/parquet/SparkParquetReadersNestedDataBenchmark.java

Lines changed: 5 additions & 4 deletions
@@ -22,6 +22,7 @@
 import com.google.common.collect.Iterables;
 import java.io.File;
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import org.apache.avro.generic.GenericData;
 import org.apache.iceberg.Files;
@@ -116,7 +117,7 @@ public void tearDownBenchmark() {
   public void readUsingIcebergReader(Blackhole blackhole) throws IOException {
     try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
         .project(SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type, Collections.emptyMap()))
         .build()) {

       for (InternalRow row : rows) {
@@ -130,7 +131,7 @@ public void readUsingIcebergReader(Blackhole blackhole) throws IOException {
   public void readUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException {
     try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
         .project(SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(SCHEMA, type, Collections.emptyMap()))
         .build()) {

       Iterable<InternalRow> unsafeRows = Iterables.transform(
@@ -167,7 +168,7 @@ public void readUsingSparkReader(Blackhole blackhole) throws IOException {
   public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException {
     try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
         .project(PROJECTED_SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type, Collections.emptyMap()))
         .build()) {

       for (InternalRow row : rows) {
@@ -181,7 +182,7 @@ public void readWithProjectionUsingIcebergReader(Blackhole blackhole) throws IOException {
   public void readWithProjectionUsingIcebergReaderUnsafe(Blackhole blackhole) throws IOException {
     try (CloseableIterable<InternalRow> rows = Parquet.read(Files.localInput(dataFile))
         .project(PROJECTED_SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(PROJECTED_SCHEMA, type, Collections.emptyMap()))
         .build()) {

       Iterable<InternalRow> unsafeRows = Iterables.transform(

spark/src/main/java/org/apache/iceberg/spark/data/SparkParquetReaders.java

Lines changed: 20 additions & 12 deletions
@@ -68,21 +68,22 @@ private SparkParquetReaders() {

   @SuppressWarnings("unchecked")
   public static ParquetValueReader<InternalRow> buildReader(Schema expectedSchema,
-                                                            MessageType fileSchema) {
+                                                            MessageType fileSchema,
+                                                            Map<Integer, Object> partitionValues) {
     if (ParquetSchemaUtil.hasIds(fileSchema)) {
       return (ParquetValueReader<InternalRow>)
           TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema,
-              new ReadBuilder(fileSchema));
+              new ReadBuilder(fileSchema, partitionValues));
     } else {
       return (ParquetValueReader<InternalRow>)
           TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema,
-              new FallbackReadBuilder(fileSchema));
+              new FallbackReadBuilder(fileSchema, partitionValues));
     }
   }

   private static class FallbackReadBuilder extends ReadBuilder {
-    FallbackReadBuilder(MessageType type) {
-      super(type);
+    FallbackReadBuilder(MessageType type, Map<Integer, Object> partitionValues) {
+      super(type, partitionValues);
     }

     @Override
@@ -113,9 +114,11 @@ public ParquetValueReader<?> struct(Types.StructType ignored, GroupType struct,

   private static class ReadBuilder extends TypeWithSchemaVisitor<ParquetValueReader<?>> {
     private final MessageType type;
+    private final Map<Integer, Object> partitionValues;

-    ReadBuilder(MessageType type) {
+    ReadBuilder(MessageType type, Map<Integer, Object> partitionValues) {
       this.type = type;
+      this.partitionValues = partitionValues;
     }

     @Override
@@ -146,13 +149,18 @@ public ParquetValueReader<?> struct(Types.StructType expected, GroupType struct,
       List<Type> types = Lists.newArrayListWithExpectedSize(expectedFields.size());
       for (Types.NestedField field : expectedFields) {
         int id = field.fieldId();
-        ParquetValueReader<?> reader = readersById.get(id);
-        if (reader != null) {
-          reorderedFields.add(reader);
-          types.add(typesById.get(id));
-        } else {
-          reorderedFields.add(ParquetValueReaders.nulls());
+        if (partitionValues.containsKey(id)) {
+          reorderedFields.add(ParquetValueReaders.constant(partitionValues.get(id)));
           types.add(null);
+        } else {
+          ParquetValueReader<?> reader = readersById.get(id);
+          if (reader != null) {
+            reorderedFields.add(reader);
+            types.add(typesById.get(id));
+          } else {
+            reorderedFields.add(ParquetValueReaders.nulls());
+            types.add(null);
+          }
        }
       }

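
In short, the new struct() logic resolves each expected field in priority order: constant partition value, then the file's own column reader, then nulls. A sketch restating that branch as a standalone helper (a hypothetical method; ParquetValueReader and ParquetValueReaders are the Iceberg types used above, and the maps mirror the method's locals):

    // Sketch: reader chosen for one expected field, in priority order:
    // 1. constant identity-partition value (the column is never read from the file)
    // 2. the file's own column reader
    // 3. nulls, when the file predates the column
    private static ParquetValueReader<?> readerFor(
        int id,
        Map<Integer, Object> partitionValues,
        Map<Integer, ParquetValueReader<?>> readersById) {
      if (partitionValues.containsKey(id)) {
        return ParquetValueReaders.constant(partitionValues.get(id));
      }
      ParquetValueReader<?> reader = readersById.get(id);
      return reader != null ? reader : ParquetValueReaders.nulls();
    }

Because identity-partition values are constant per file, serving them from the reader itself removes the need for the joined-row and Unsafe projection pass that this commit deletes from Reader.java below.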

spark/src/main/java/org/apache/iceberg/spark/source/Reader.java

Lines changed: 23 additions & 29 deletions
@@ -22,13 +22,14 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
-import com.google.common.collect.Iterators;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import java.io.Closeable;
 import java.io.IOException;
 import java.io.Serializable;
 import java.nio.ByteBuffer;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -69,7 +70,6 @@
 import org.apache.spark.sql.catalyst.expressions.Attribute;
 import org.apache.spark.sql.catalyst.expressions.AttributeReference;
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
-import org.apache.spark.sql.catalyst.expressions.JoinedRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.sources.Filter;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
@@ -397,7 +397,6 @@ private Iterator<InternalRow> open(FileScanTask task) {
       // schema or rows returned by readers
       Schema finalSchema = expectedSchema;
       PartitionSpec spec = task.spec();
-
       Set<Integer> idColumns = Sets.newHashSet();
       for (Integer i : spec.identitySourceIds()) {
         if (spec.schema().columns().stream()
@@ -407,46 +406,39 @@ private Iterator<InternalRow> open(FileScanTask task) {
         }
       }

-      // schema needed for the projection and filtering
       StructType sparkType = SparkSchemaUtil.convert(finalSchema);
       Schema requiredSchema = SparkSchemaUtil.prune(tableSchema, sparkType, task.residual(), caseSensitive);
       boolean hasJoinedPartitionColumns = !idColumns.isEmpty();
       boolean hasExtraFilterColumns = requiredSchema.columns().size() != finalSchema.columns().size();

-      Schema iterSchema;
       Iterator<InternalRow> iter;

       if (hasJoinedPartitionColumns) {
-        // schema used to read data files
-        Schema readSchema = TypeUtil.selectNot(requiredSchema, idColumns);
         Schema partitionSchema = TypeUtil.select(requiredSchema, idColumns);
         PartitionRowConverter convertToRow = new PartitionRowConverter(partitionSchema, spec);
-        JoinedRow joined = new JoinedRow();
+        GenericInternalRow partition = (GenericInternalRow) convertToRow.apply(file.partition());

-        InternalRow partition = convertToRow.apply(file.partition());
-        joined.withRight(partition);
+        Map<Integer, Object> partitionValueMap = Maps.newHashMap();
+        Map<String, Integer> partitionSpecFieldIndexMap = Maps.newHashMap();
+        for (int i = 0; i < spec.fields().size(); i++) {
+          partitionSpecFieldIndexMap.put(spec.fields().get(i).name(), i);
+        }

-        // create joined rows and project from the joined schema to the final schema
-        iterSchema = TypeUtil.join(readSchema, partitionSchema);
-        iter = Iterators.transform(open(task, readSchema), joined::withLeft);
+        for (Types.NestedField field : partitionSchema.columns()) {
+          int partitionIndex = partitionSpecFieldIndexMap.get(field.name());
+          partitionValueMap.put(field.fieldId(), partition.genericGet(partitionIndex));
+        }

+        iter = open(task, finalSchema, partitionValueMap);
       } else if (hasExtraFilterColumns) {
-        // add projection to the final schema
-        iterSchema = requiredSchema;
-        iter = open(task, requiredSchema);
-
+        iter = open(task, requiredSchema, Collections.emptyMap());
       } else {
-        // return the base iterator
-        iterSchema = finalSchema;
-        iter = open(task, finalSchema);
+        iter = open(task, finalSchema, Collections.emptyMap());
       }
-
-      // TODO: remove the projection by reporting the iterator's schema back to Spark
-      return Iterators.transform(iter,
-          APPLY_PROJECTION.bind(projection(finalSchema, iterSchema))::invoke);
+      return iter;
     }

-    private Iterator<InternalRow> open(FileScanTask task, Schema readSchema) {
+    private Iterator<InternalRow> open(FileScanTask task, Schema readSchema, Map<Integer, Object> partitionValues) {
      CloseableIterable<InternalRow> iter;
      if (task.isDataTask()) {
        iter = newDataIterable(task.asDataTask(), readSchema);
@@ -457,7 +449,7 @@ private Iterator<InternalRow> open(FileScanTask task, Schema readSchema) {

      switch (task.file().format()) {
        case PARQUET:
-         iter = newParquetIterable(location, task, readSchema);
+         iter = newParquetIterable(location, task, readSchema, partitionValues);
          break;

        case AVRO:
@@ -513,12 +505,14 @@ private CloseableIterable<InternalRow> newAvroIterable(InputFile location,
     }

     private CloseableIterable<InternalRow> newParquetIterable(InputFile location,
-                                                              FileScanTask task,
-                                                              Schema readSchema) {
+                                                              FileScanTask task,
+                                                              Schema readSchema,
+                                                              Map<Integer, Object> partitionValues) {
+
       return Parquet.read(location)
           .project(readSchema)
           .split(task.start(), task.length())
-          .createReaderFunc(fileSchema -> SparkParquetReaders.buildReader(readSchema, fileSchema))
+          .createReaderFunc(fileSchema -> SparkParquetReaders.buildReader(readSchema, fileSchema, partitionValues))
           .filter(task.residual())
           .caseSensitive(caseSensitive)
           .build();
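
The map handed to buildReader is keyed by Iceberg field ID, with values pulled out of the file's partition tuple by the field's position in the partition spec. A self-contained sketch of that construction, using the same calls as the diff but extracted into a hypothetical helper (spec, partitionSchema, and partition correspond to the locals in open(FileScanTask) above):

    // Sketch: fieldId -> partition value for this file's identity-partitioned columns.
    static Map<Integer, Object> toPartitionValueMap(
        PartitionSpec spec, Schema partitionSchema, GenericInternalRow partition) {
      // index of each partition field within the partition tuple, looked up by name
      Map<String, Integer> indexByName = Maps.newHashMap();
      for (int i = 0; i < spec.fields().size(); i++) {
        indexByName.put(spec.fields().get(i).name(), i);
      }
      Map<Integer, Object> valuesById = Maps.newHashMap();
      for (Types.NestedField field : partitionSchema.columns()) {
        int pos = indexByName.get(field.name());
        valuesById.put(field.fieldId(), partition.genericGet(pos));
      }
      return valuesById;
    }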

spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetReader.java

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@

 import java.io.File;
 import java.io.IOException;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import org.apache.avro.generic.GenericData;
@@ -57,7 +58,7 @@ protected void writeAndValidate(Schema schema) throws IOException {

     try (CloseableIterable<InternalRow> reader = Parquet.read(Files.localInput(testFile))
         .project(schema)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(schema, type, Collections.emptyMap()))
         .build()) {
       Iterator<InternalRow> rows = reader.iterator();
       for (int i = 0; i < expected.size(); i += 1) {

spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@

 import java.io.File;
 import java.io.IOException;
+import java.util.Collections;
 import java.util.Iterator;
 import org.apache.iceberg.Files;
 import org.apache.iceberg.Schema;
@@ -85,7 +86,7 @@ public void testCorrectness() throws IOException {

     try (CloseableIterable<InternalRow> reader = Parquet.read(Files.localInput(testFile))
         .project(COMPLEX_SCHEMA)
-        .createReaderFunc(type -> SparkParquetReaders.buildReader(COMPLEX_SCHEMA, type))
+        .createReaderFunc(type -> SparkParquetReaders.buildReader(COMPLEX_SCHEMA, type, Collections.emptyMap()))
         .build()) {
       Iterator<InternalRow> expected = records.iterator();
       Iterator<InternalRow> rows = reader.iterator();
