diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordReader.java b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordReader.java index e57f3cbcee..80debd7b4f 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordReader.java +++ b/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/InternalParquetRecordReader.java @@ -180,12 +180,14 @@ public void initialize(ParquetFileReader reader, ParquetReadOptions options) { this.columnIOFactory = new ColumnIOFactory(parquetFileMetadata.getCreatedBy()); this.requestedSchema = readContext.getRequestedSchema(); this.columnCount = requestedSchema.getPaths().size(); + // Setting the projection schema before running any filtering (e.g. getting filtered record count) + // because projection impacts filtering + reader.setRequestedSchema(requestedSchema); this.recordConverter = readSupport.prepareForRead(conf, fileMetadata, fileSchema, readContext); this.strictTypeChecking = options.isEnabled(STRICT_TYPE_CHECKING, true); this.total = reader.getFilteredRecordCount(); this.unmaterializableRecordCounter = new UnmaterializableRecordCounter(options, total); this.filterRecords = options.useRecordFilter(); - reader.setRequestedSchema(requestedSchema); LOG.info("RecordReader initialized will read a total of {} records.", total); } @@ -201,13 +203,15 @@ public void initialize(ParquetFileReader reader, Configuration configuration) this.columnIOFactory = new ColumnIOFactory(parquetFileMetadata.getCreatedBy()); this.requestedSchema = readContext.getRequestedSchema(); this.columnCount = requestedSchema.getPaths().size(); + // Setting the projection schema before running any filtering (e.g. getting filtered record count) + // because projection impacts filtering + reader.setRequestedSchema(requestedSchema); this.recordConverter = readSupport.prepareForRead( configuration, fileMetadata, fileSchema, readContext); this.strictTypeChecking = configuration.getBoolean(STRICT_TYPE_CHECKING, true); this.total = reader.getFilteredRecordCount(); this.unmaterializableRecordCounter = new UnmaterializableRecordCounter(configuration, total); this.filterRecords = configuration.getBoolean(RECORD_FILTERING_ENABLED, true); - reader.setRequestedSchema(requestedSchema); LOG.info("RecordReader initialized will read a total of {} records.", total); } diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java index 18ddca0d96..6355f35c3c 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java @@ -199,6 +199,10 @@ public int hashCode() { public String toString() { return "User [id=" + id + ", name=" + name + ", phoneNumbers=" + phoneNumbers + ", location=" + location + "]"; } + + public User cloneWithName(String name) { + return new User(id, name, phoneNumbers, location); + } } public static SimpleGroup groupFromUser(User user) { @@ -257,6 +261,10 @@ private static Location getLocation(Group location) { } private static boolean isNull(Group group, String field) { + // Use null value if the field is not in the group schema + if (!group.getType().containsField(field)) { + return true; + } int repetition = group.getFieldRepetitionCount(field); if (repetition == 0) { return true; diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java index ccb6a03d51..c18212e026 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestColumnIndexFiltering.java @@ -19,6 +19,7 @@ package org.apache.parquet.hadoop; import static java.util.Collections.emptyList; +import static java.util.stream.Collectors.toList; import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; @@ -33,6 +34,12 @@ import static org.apache.parquet.filter2.predicate.FilterApi.userDefined; import static org.apache.parquet.filter2.predicate.LogicalInverter.invert; import static org.apache.parquet.hadoop.ParquetFileWriter.Mode.OVERWRITE; +import static org.apache.parquet.schema.LogicalTypeAnnotation.stringType; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; +import static org.apache.parquet.schema.Types.optional; +import static org.apache.parquet.schema.Types.required; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -64,9 +71,12 @@ import org.apache.parquet.filter2.recordlevel.PhoneBookWriter.Location; import org.apache.parquet.filter2.recordlevel.PhoneBookWriter.PhoneNumber; import org.apache.parquet.filter2.recordlevel.PhoneBookWriter.User; +import org.apache.parquet.hadoop.api.ReadSupport; import org.apache.parquet.hadoop.example.ExampleParquetWriter; import org.apache.parquet.hadoop.example.GroupReadSupport; import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Types; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -87,6 +97,19 @@ public class TestColumnIndexFiltering { private static final List DATA = Collections.unmodifiableList(generateData(10000)); private static final Path FILE_V1 = createTempFile(); private static final Path FILE_V2 = createTempFile(); + private static final MessageType SCHEMA_WITHOUT_NAME = Types.buildMessage() + .required(INT64).named("id") + .optionalGroup() + .addField(optional(DOUBLE).named("lon")) + .addField(optional(DOUBLE).named("lat")) + .named("location") + .optionalGroup() + .repeatedGroup() + .addField(required(INT64).named("number")) + .addField(optional(BINARY).as(stringType()).named("kind")) + .named("phone") + .named("phoneNumbers") + .named("user_without_name"); @Parameters public static Collection params() { @@ -199,6 +222,16 @@ private List readUsers(Filter filter, boolean useOtherFiltering, boolean u .useColumnIndexFilter(useColumnIndexFilter)); } + private List readUsersWithProjection(Filter filter, MessageType schema, boolean useOtherFiltering, boolean useColumnIndexFilter) throws IOException { + return PhoneBookWriter.readUsers(ParquetReader.builder(new GroupReadSupport(), file) + .withFilter(filter) + .useDictionaryFilter(useOtherFiltering) + .useStatsFilter(useOtherFiltering) + .useRecordFilter(useOtherFiltering) + .useColumnIndexFilter(useColumnIndexFilter) + .set(ReadSupport.PARQUET_READ_SCHEMA, schema.toString())); + } + // Assumes that both lists are in the same order private static void assertContains(Stream expected, List actual) { Iterator expIt = expected.iterator(); @@ -441,4 +474,21 @@ record -> record.getId() == 1234, or(eq(longColumn("id"), 1234l), userDefined(longColumn("not-existing-long"), new IsDivisibleBy(1)))); } + + @Test + public void testFilteringWithProjection() throws IOException { + // All rows shall be retrieved because all values in column 'name' shall be handled as null values + assertEquals( + DATA.stream().map(user -> user.cloneWithName(null)).collect(toList()), + readUsersWithProjection(FilterCompat.get(eq(binaryColumn("name"), null)), SCHEMA_WITHOUT_NAME, true, true)); + + // Column index filter shall drop all pages because all values in column 'name' shall be handled as null values + assertEquals( + emptyList(), + readUsersWithProjection(FilterCompat.get(notEq(binaryColumn("name"), null)), SCHEMA_WITHOUT_NAME, false, true)); + assertEquals( + emptyList(), + readUsersWithProjection(FilterCompat.get(userDefined(binaryColumn("name"), NameStartsWithVowel.class)), + SCHEMA_WITHOUT_NAME, false, true)); + } }