[Kernel][Defaults] Add support for pushing IS [NOT] NULL into Parquet reader (#3292)

## Description
Allows pushing the predicates `IS NULL` and `IS NOT NULL` down into the
default Parquet reader. This helps prune the number of row groups read
based on the predicates.
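
For illustration, here is a minimal standalone sketch of the mapping this enables, using parquet-mr's `FilterApi` (the column name `rating` is invented for the example). In parquet-mr, passing a `null` literal to `eq`/`notEq` expresses "value is null" / "value is not null":

```java
import org.apache.parquet.filter2.compat.FilterCompat;
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;
import org.apache.parquet.filter2.predicate.Operators.IntColumn;

public class IsNullPushdownSketch {
  public static void main(String[] args) {
    IntColumn rating = FilterApi.intColumn("rating");

    // Kernel's `is_null(rating)` maps to eq(rating, null): keep records whose
    // value is null. `is_not_null(rating)` maps to notEq(rating, null).
    FilterPredicate isNull = FilterApi.eq(rating, null);
    FilterPredicate isNotNull = FilterApi.notEq(rating, null);

    // Wrapped in a FilterCompat.Filter, parquet-mr can skip whole row groups
    // whose statistics (e.g. null counts) prove the predicate cannot match.
    FilterCompat.Filter filter = FilterCompat.get(isNotNull);
    System.out.println(isNull + " / " + isNotNull);
  }
}
```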

## How was this patch tested?
Unit tests
vkorukanti authored Jun 24, 2024
1 parent 5ea073b commit 9f03492
Showing 2 changed files with 211 additions and 28 deletions.
@@ -19,15 +19,15 @@

import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.filter2.compat.FilterCompat.Filter;
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;
import org.apache.parquet.filter2.predicate.*;
import org.apache.parquet.filter2.predicate.Operators.*;
import org.apache.parquet.hadoop.metadata.ColumnPath;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.schema.*;
import org.apache.parquet.schema.LogicalTypeAnnotation.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.apache.parquet.filter2.predicate.FilterApi.*;

import io.delta.kernel.expressions.*;
import io.delta.kernel.expressions.Column;
@@ -154,6 +154,12 @@ private static Optional<FilterPredicate> convertToParquetFilter(
return convertAndToParquetFilter(parquetFieldMap, deltaPredicate);
case "or":
return convertOrToParquetFilter(parquetFieldMap, deltaPredicate);
case "is_null":
return convertIsNullIsNotNull(
parquetFieldMap, deltaPredicate, false /* isNotNull */);
case "is_not_null":
return convertIsNullIsNotNull(
parquetFieldMap, deltaPredicate, true /* isNotNull */);
default:
return visitUnsupported(deltaPredicate, name + " is not a supported predicate.");
}
@@ -206,13 +212,13 @@ private static Optional<FilterPredicate> convertComparatorToParquetFilter(

switch (parquetType.getPrimitiveTypeName()) {
case BOOLEAN:
BooleanColumn booleanColumn = FilterApi.booleanColumn(columnPath);
BooleanColumn booleanColumn = booleanColumn(columnPath);
if ("=".equals(comparator)) { // Only = is supported for boolean
return Optional.of(FilterApi.eq(booleanColumn, getBoolean(literal)));
}
break;
case INT32:
IntColumn intColumn = FilterApi.intColumn(columnPath);
IntColumn intColumn = intColumn(columnPath);
switch (comparator) {
case "=":
return Optional.of(FilterApi.eq(intColumn, getInt(literal)));
@@ -227,7 +233,7 @@ private static Optional<FilterPredicate> convertComparatorToParquetFilter(
}
break;
case INT64:
LongColumn longColumn = FilterApi.longColumn(columnPath);
LongColumn longColumn = longColumn(columnPath);
switch (comparator) {
case "=":
return Optional.of(FilterApi.eq(longColumn, getLong(literal)));
@@ -242,7 +248,7 @@ private static Optional<FilterPredicate> convertComparatorToParquetFilter(
}
break;
case FLOAT:
FloatColumn floatColumn = FilterApi.floatColumn(columnPath);
FloatColumn floatColumn = floatColumn(columnPath);
switch (comparator) {
case "=":
return Optional.of(FilterApi.eq(floatColumn, getFloat(literal)));
@@ -257,7 +263,7 @@ private static Optional<FilterPredicate> convertComparatorToParquetFilter(
}
break;
case DOUBLE:
DoubleColumn doubleColumn = FilterApi.doubleColumn(columnPath);
DoubleColumn doubleColumn = doubleColumn(columnPath);
switch (comparator) {
case "=":
return Optional.of(FilterApi.eq(doubleColumn, getDouble(literal)));
@@ -272,7 +278,7 @@ private static Optional<FilterPredicate> convertComparatorToParquetFilter(
}
break;
case BINARY:
BinaryColumn binaryColumn = FilterApi.binaryColumn(columnPath);
BinaryColumn binaryColumn = binaryColumn(columnPath);
Binary binary = getBinary(literal);
switch (comparator) {
case "=":
@@ -334,6 +340,56 @@ private static Optional<FilterPredicate> convertAndToParquetFilter(
return rightFilter;
}

private static Optional<FilterPredicate> convertIsNullIsNotNull(
Map<Column, ParquetField> parquetFieldMap,
Predicate deltaPredicate,
boolean isNotNull) {
Expression child = getUnaryChild(deltaPredicate);
if (!(child instanceof Column)) {
return visitUnsupported(deltaPredicate, "IS NULL predicate must have a column input.");
}

Column column = (Column) child;
ParquetField parquetField = parquetFieldMap.get(column);
if (parquetField == null) {
return visitUnsupported(
deltaPredicate,
"Column used in predicate does not exist in the parquet file.");
}

String columnPath = ColumnPath.get(column.getNames()).toDotString();
// Parquet filter keeps records if their value is equal to the provided value.
// Nulls are treated the same way the java programming language does.
// For example: eq(column, null) will keep all records whose value is null. eq(column, 7)
// will keep all records whose value is 7, and will drop records whose value is null
// NOTE: this is different from how some query languages handle null.
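// (Row-group pruning builds on this: when column statistics report a null
// count of zero, eq(column, null) can never match, so the reader can skip
// the entire row group; notEq(column, null) is the symmetric case.)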
switch (parquetField.primitiveType.getPrimitiveTypeName()) {
case BOOLEAN:
return createIsNullOrIsNotNullPredicate(booleanColumn(columnPath), isNotNull);
case INT32:
return createIsNullOrIsNotNullPredicate(intColumn(columnPath), isNotNull);
case INT64:
return createIsNullOrIsNotNullPredicate(longColumn(columnPath), isNotNull);
case FLOAT:
return createIsNullOrIsNotNullPredicate(floatColumn(columnPath), isNotNull);
case DOUBLE:
return createIsNullOrIsNotNullPredicate(doubleColumn(columnPath), isNotNull);
case BINARY:
return createIsNullOrIsNotNullPredicate(binaryColumn(columnPath), isNotNull);
default:
return visitUnsupported(
deltaPredicate,
"Unsupported column type: " + parquetField.primitiveType);
}
}

private static <T extends Comparable<T>, C extends Operators.Column<T> & SupportsEqNotEq>
Optional<FilterPredicate> createIsNullOrIsNotNullPredicate(
C column,
boolean isNotNull) {
return Optional.of(isNotNull ? FilterApi.notEq(column, null) : FilterApi.eq(column, null));
}

private static Optional<FilterPredicate> visitUnsupported(
Predicate predicate,
String message) {
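The helper `createIsNullOrIsNotNullPredicate` above uses the intersection bound `C extends Operators.Column<T> & SupportsEqNotEq` because `FilterApi.eq`/`notEq` are themselves declared against that bound. A standalone sketch of the same shape, with a usage example (an illustration, not the Kernel source):

```java
import java.util.Optional;
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;
import org.apache.parquet.filter2.predicate.Operators;
import org.apache.parquet.filter2.predicate.Operators.SupportsEqNotEq;

final class NullPredicateSketch {
  // A null literal passed to eq/notEq expresses IS NULL / IS NOT NULL for any
  // Parquet column type that supports equality predicates.
  static <T extends Comparable<T>, C extends Operators.Column<T> & SupportsEqNotEq>
      Optional<FilterPredicate> isNullOrNot(C column, boolean isNotNull) {
    return Optional.of(
        isNotNull ? FilterApi.notEq(column, null) : FilterApi.eq(column, null));
  }

  public static void main(String[] args) {
    // e.g. `longCol IS NOT NULL` and `intCol IS NULL`:
    System.out.println(isNullOrNot(FilterApi.longColumn("longCol"), true).get());
    System.out.println(isNullOrNot(FilterApi.intColumn("intCol"), false).get());
  }
}
```
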
@@ -92,20 +92,70 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite

private def generateRowsGroup(rowGroupIdx: Int): Seq[Row] = {
def values(rowId: Int): Seq[Any] = {
// One of the columns in each row group is all nulls or all non-nulls, depending on
// the [[rowGroupIdx]]. This helps verify the results of the `is null` and
// `is not null` pushdown tests.
Seq(
if (rowId % 72 != 0) rowId.byteValue() else null,
if (rowId % 56 != 0) rowId.shortValue() else null,
if (rowId % 23 != 0) rowId else null,
if (rowId % 25 != 0) (rowId + 1).longValue() else null,
if (rowId % 28 != 0) (rowId + 0.125).floatValue() else null,
if (rowId % 54 != 0) (rowId + 0.000001).doubleValue() else null,
if (rowId % 57 != 0) "%05d".format(rowId) else null,
if (rowId % 57 != 0) "%050d".format(rowId) else null, // truncated stats
if (rowId % 59 != 0) "%06d".format(rowId).getBytes else null,
if (rowId % 59 != 0) "%060d".format(rowId).getBytes else null, // truncated stats
// alternate between true and false for each row group
(rowId / 100) % 2 == 0,
if (rowId % 61 != 0) new Date(rowId * 86400000L /* millis in a day */) else null
// byteCol
if (rowGroupIdx == 0) null /* all nulls */
else if (rowGroupIdx == 11) rowId.byteValue() /* all non-nulls */
else (if (rowId % 72 != 0) rowId.byteValue() else null), /* mix of nulls and non-nulls */

// shortCol
if (rowGroupIdx == 1) null
else if (rowGroupIdx == 10) rowId.shortValue()
else (if (rowId % 56 != 0) rowId.shortValue() else null),

// intCol
if (rowGroupIdx == 2) null
else if (rowGroupIdx == 9) rowId
else (if (rowId % 23 != 0) rowId else null),

// longCol
if (rowGroupIdx == 3) null
else if (rowGroupIdx == 8) (rowId + 1).longValue()
else (if (rowId % 25 != 0) (rowId + 1).longValue() else null),

// floatCol
if (rowGroupIdx == 4) null
else if (rowGroupIdx == 7) (rowId + 0.125).floatValue()
else (if (rowId % 28 != 0) (rowId + 0.125).floatValue() else null),

// doubleCol
if (rowGroupIdx == 5) null
else if (rowGroupIdx == 6) (rowId + 0.000001).doubleValue()
else (if (rowId % 54 != 0) (rowId + 0.000001).doubleValue() else null),

// stringCol
if (rowGroupIdx == 6) null
else if (rowGroupIdx == 5) "%05d".format(rowId)
else (if (rowId % 57 != 0) "%05d".format(rowId) else null),

// truncatedStringCol - stats will be truncated as the value is too long
if (rowGroupIdx == 7) null
else if (rowGroupIdx == 4) "%050d".format(rowId)
else (if (rowId % 57 != 0) "%050d".format(rowId) else null),

// binaryCol
if (rowGroupIdx == 8) null
else if (rowGroupIdx == 3) "%06d".format(rowId).getBytes
else (if (rowId % 59 != 0) "%06d".format(rowId).getBytes else null),

// truncatedBinaryCol - stats will be truncated as the value is too long
if (rowGroupIdx == 9) null
else if (rowGroupIdx == 2) "%060d".format(rowId).getBytes
else (if (rowId % 59 != 0) "%060d".format(rowId).getBytes else null),

// booleanCol
if (rowGroupIdx == 10) null
else if (rowGroupIdx == 1) rowId % 2 == 0
// alternate between true and false for each row group
else (if (rowId % 29 != 0) rowGroupIdx % 2 == 0 else null),

// dateCol
if (rowGroupIdx == 11) null
else if (rowGroupIdx == 0) new Date(rowId * 86400000L)
else (if (rowId % 61 != 0) new Date(rowId * 86400000L) else null)
)
}
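
The per-column layout above follows a single pattern: taking the columns in the order listed, column k is entirely null in row group k and entirely non-null in row group 11 - k, with sparse modulo-based nulls everywhere else. A distilled sketch of that rule (hypothetical helper, written in Java for brevity):

```java
final class NullLayoutSketch {
  // The test table has 12 columns and 20 row groups of 100 rows each.
  // Column k is all-null in row group k and all-non-null in row group (11 - k);
  // in every other group, a cell is null iff rowId % mixModulo == 0
  // (mixModulo is per column, e.g. 72 for byteCol, 23 for intCol).
  static boolean cellIsNull(int colIdx, int rowGroupIdx, int rowId, int mixModulo) {
    if (rowGroupIdx == colIdx) return true;        // the all-nulls group for this column
    if (rowGroupIdx == 11 - colIdx) return false;  // the all-non-nulls group
    return rowId % mixModulo == 0;                 // sparse nulls in the remaining groups
  }
}
```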

@@ -162,7 +212,7 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite
(
lt(col("floatCol"), ofFloat(1000.0f)),
lt(col("nested", "floatCol"), ofFloat(1000.0f)),
Seq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) // expected row groups
Seq(0, 1, 2, 3, 5, 6, 7, 8, 9) // expected row groups - row group 4 has all nulls
),
// filter on double type column
(
@@ -174,7 +224,9 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite
(
eq(col("booleanCol"), ofBoolean(true)),
eq(col("nested", "booleanCol"), ofBoolean(true)),
Seq(0, 2, 4, 6, 8, 10, 12, 14, 16, 18) // expected row groups
// expected row groups
// 1 has mix of true/false (included), 10 has all nulls (not included)
Seq(0, 1, 2, 4, 6, 8, 12, 14, 16, 18)
),
// filter on date type column
(
@@ -200,7 +252,8 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite
(
gte(col("truncatedStringCol"), ofString("%050d".format(300))),
gte(col("nested", "truncatedStringCol"), ofString("%050d".format(300))),
Seq(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19) // expected row groups
// expected row groups
Seq(3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19) // 7 has all nulls
),
// filter on truncated stats binary type column
(
@@ -222,6 +275,56 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite
}
}

// IS NULL and IS NOT NULL tests
Seq(
// (columnName, row groups with all nulls, row groups with all non-nulls)
("byteCol", Seq(0), Seq(11)), // int type column
("shortCol", Seq(1), Seq(10)), // short type column
("intCol", Seq(2), Seq(9)), // int type column
("longCol", Seq(3), Seq(8)), // long type column
("floatCol", Seq(4), Seq(7)), // float type column
("doubleCol", Seq(5), Seq(6)), // double type column
("stringCol", Seq(6), Seq(5)), // string type column
("truncatedStringCol", Seq(7), Seq(4)), // truncatedStringCol type column
("binaryCol", Seq(8), Seq(3)), // binary type column
("truncatedBinaryCol", Seq(9), Seq(2)), // truncatedBinaryCol type column
("booleanCol", Seq(10), Seq(1)), // boolean type column
("dateCol", Seq(11), Seq(0)) // date type column
).foreach {
// Test table has 20 row groups, each with 100 rows.
case (colName, allNullsRowGroups, allNonNullsRowGroups) =>
// Test predicate on both top-level and nested columns
Seq(col(colName), col("nested", colName)).foreach { column =>
val isNullFilter = isNull(column)
test(s"filter pushdown: $isNullFilter") {
val actualData = readUsingKernel(testParquetTable, isNullFilter)
val expOutputRowCount = 100 * (20 - 1) // 100 rows per row group

// we get everything except the row group that has all non-nulls
val expRowGroups = (0 until 20).filter(!allNonNullsRowGroups.contains(_))
assert(actualData.size === expOutputRowCount, s"predicate: $isNullFilter")
checkAnswer(actualData, generateExpData(expRowGroups))

// not (col is null) should return all row groups except the one with all nulls
assertNot(isNullFilter, (0 until 20).filter(!allNullsRowGroups.contains(_)))
}

val isNotNullFilter = isNotNull(column)
test(s"filter pushdown: $isNotNullFilter") {
val actualData = readUsingKernel(testParquetTable, isNotNullFilter)
val expOutputRowCount = 100 * (20 - 1) // 100 rows per row group

// we get everything except the row group that has all nulls
val expRowGroups = (0 until 20).filter(!allNullsRowGroups.contains(_))
assert(actualData.size === expOutputRowCount, s"predicate: $isNotNullFilter")
checkAnswer(actualData, generateExpData(expRowGroups))

// not (col is not null) should return all row groups except the one with all non-nulls
assertNot(isNotNullFilter, (0 until 20).filter(!allNonNullsRowGroups.contains(_)))
}
}
}
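
To make the expected counts concrete, take `intCol`: row group 2 is all nulls and row group 9 is all non-nulls. The pushdown prunes whole row groups only — every row of each surviving group is returned, whether or not it individually matches — so the assertions reduce to the following arithmetic (plain Java, names invented):

```java
public class ExpectedCountsSketch {
  public static void main(String[] args) {
    int totalGroups = 20, rowsPerGroup = 100;

    // IS NULL is provably false only for the all-non-nulls group (9), so 19 of
    // the 20 row groups survive pruning.
    int isNullRows = (totalGroups - 1) * rowsPerGroup;     // 1900

    // IS NOT NULL symmetrically drops only the all-nulls group (2).
    int isNotNullRows = (totalGroups - 1) * rowsPerGroup;  // 1900

    System.out.println(isNullRows + " " + isNotNullRows);
  }
}
```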

test("for a column that doesn't exist in the table") {
val testPredicate = predicate("=", col("nonExistentCol"), ofInt(20))
assertConvertedFilterIsEmpty(testPredicate, testParquetTable)
@@ -303,11 +406,26 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite
checkAnswer(actData, generateExpData(Seq(15)))
}

test("not support") {
val predicate = not(eq(col("booleanCol"), ofBoolean(true)))
test("not support on gt") {
val predicate = not(gt(col("intCol"), ofInt(950)))
val actData = readUsingKernel(testParquetTable, predicate)
// every odd rowgroup has true for booleanCol
checkAnswer(actData, generateExpData(Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19)))

// row groups up to and including 9 could have values <= 950
// rowgroup 2 has all nulls, so it won't be included in the result
val expRowGroups = Seq(0, 1, 3, 4, 5, 6, 7, 8, 9)
val expOutputRowCount = expRowGroups.length * 100 // 100 rows per row group
assert(actData.size === expOutputRowCount, s"predicate: $predicate")

checkAnswer(actData, generateExpData(expRowGroups))
}

test("not support on equality") {
val predicate = not(eq(col("longCol"), ofLong(768)))
val actData = readUsingKernel(testParquetTable, predicate)
// row group 3 has all nulls, so it is included in the results because the
// Parquet equality filter is not null safe;
// every other group has values that are not 768
checkAnswer(actData, generateExpData(Seq.range(0, 20)))
}
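
The asymmetry between the two `not` tests comes from how parquet-mr inverts predicates: `not(gt(c, v))` collapses to `ltEq(c, v)`, which a null value can never satisfy, while `not(eq(c, v))` collapses to `notEq(c, v)`, which a null value does satisfy under Java-style equality — hence the all-nulls row group 3 survives in the second test. A sketch, assuming parquet-mr's standard inverse rewriting:

```java
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;

public class NotPushdownSketch {
  public static void main(String[] args) {
    // not(intCol > 950) inverts to intCol <= 950: a null cannot pass a range
    // comparison, so an all-nulls row group is prunable.
    FilterPredicate notGt =
        FilterApi.not(FilterApi.gt(FilterApi.intColumn("intCol"), 950));

    // not(longCol = 768) inverts to longCol != 768: null != 768 under Java
    // semantics, so an all-nulls row group must be kept.
    FilterPredicate notEq =
        FilterApi.not(FilterApi.eq(FilterApi.longColumn("longCol"), 768L));

    System.out.println(notGt + " / " + notEq);
  }
}
```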

test("doesn't work on the repeated columns") {
@@ -322,4 +440,13 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite

checkAnswer(actResult, expResult)
}

/** Tests that `not(predicate)` returns the expected row groups. */
private def assertNot(predicate: Predicate, expRowGroups: Seq[Int]): Unit = {
val notPredicate = not(predicate)
val actualData = readUsingKernel(testParquetTable, notPredicate)
val expOutputRowCount = expRowGroups.length * 100 // 100 rows per row group
assert(actualData.size === expOutputRowCount, s"predicate: $notPredicate")
checkAnswer(actualData, generateExpData(expRowGroups))
}
}
