diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetFilterUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetFilterUtils.java index b9365bb848d..0e1a2b9b947 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetFilterUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetFilterUtils.java @@ -19,8 +19,7 @@ import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.filter2.compat.FilterCompat.Filter; -import org.apache.parquet.filter2.predicate.FilterApi; -import org.apache.parquet.filter2.predicate.FilterPredicate; +import org.apache.parquet.filter2.predicate.*; import org.apache.parquet.filter2.predicate.Operators.*; import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.io.api.Binary; @@ -28,6 +27,7 @@ import org.apache.parquet.schema.LogicalTypeAnnotation.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.apache.parquet.filter2.predicate.FilterApi.*; import io.delta.kernel.expressions.*; import io.delta.kernel.expressions.Column; @@ -154,6 +154,12 @@ private static Optional convertToParquetFilter( return convertAndToParquetFilter(parquetFieldMap, deltaPredicate); case "or": return convertOrToParquetFilter(parquetFieldMap, deltaPredicate); + case "is_null": + return convertIsNullIsNotNull( + parquetFieldMap, deltaPredicate, false /* isNotNull */); + case "is_not_null": + return convertIsNullIsNotNull( + parquetFieldMap, deltaPredicate, true /* isNotNull */); default: return visitUnsupported(deltaPredicate, name + " is not a supported predicate."); } @@ -206,13 +212,13 @@ private static Optional convertComparatorToParquetFilter( switch (parquetType.getPrimitiveTypeName()) { case BOOLEAN: - BooleanColumn booleanColumn = FilterApi.booleanColumn(columnPath); + BooleanColumn booleanColumn = 
booleanColumn(columnPath); if ("=".equals(comparator)) { // Only = is supported for boolean return Optional.of(FilterApi.eq(booleanColumn, getBoolean(literal))); } break; case INT32: - IntColumn intColumn = FilterApi.intColumn(columnPath); + IntColumn intColumn = intColumn(columnPath); switch (comparator) { case "=": return Optional.of(FilterApi.eq(intColumn, getInt(literal))); @@ -227,7 +233,7 @@ private static Optional convertComparatorToParquetFilter( } break; case INT64: - LongColumn longColumn = FilterApi.longColumn(columnPath); + LongColumn longColumn = longColumn(columnPath); switch (comparator) { case "=": return Optional.of(FilterApi.eq(longColumn, getLong(literal))); @@ -242,7 +248,7 @@ private static Optional convertComparatorToParquetFilter( } break; case FLOAT: - FloatColumn floatColumn = FilterApi.floatColumn(columnPath); + FloatColumn floatColumn = floatColumn(columnPath); switch (comparator) { case "=": return Optional.of(FilterApi.eq(floatColumn, getFloat(literal))); @@ -257,7 +263,7 @@ private static Optional convertComparatorToParquetFilter( } break; case DOUBLE: - DoubleColumn doubleColumn = FilterApi.doubleColumn(columnPath); + DoubleColumn doubleColumn = doubleColumn(columnPath); switch (comparator) { case "=": return Optional.of(FilterApi.eq(doubleColumn, getDouble(literal))); @@ -272,7 +278,7 @@ private static Optional convertComparatorToParquetFilter( } break; case BINARY: - BinaryColumn binaryColumn = FilterApi.binaryColumn(columnPath); + BinaryColumn binaryColumn = binaryColumn(columnPath); Binary binary = getBinary(literal); switch (comparator) { case "=": @@ -334,6 +340,56 @@ private static Optional convertAndToParquetFilter( return rightFilter; } + private static Optional convertIsNullIsNotNull( + Map parquetFieldMap, + Predicate deltaPredicate, + boolean isNotNull) { + Expression child = getUnaryChild(deltaPredicate); + if (!(child instanceof Column)) { + return visitUnsupported(deltaPredicate, "IS NULL predicate must have a column 
input."); + } + + Column column = (Column) child; + ParquetField parquetField = parquetFieldMap.get(column); + if (parquetField == null) { + return visitUnsupported( + deltaPredicate, + "Column used in predicate does not exist in the parquet file."); + } + + String columnPath = ColumnPath.get(column.getNames()).toDotString(); + // Parquet filter keeps records if their value is equal to the provided value. + // Nulls are treated the same way the java programming language does. + // For example: eq(column, null) will keep all records whose value is null. eq(column, 7) + // will keep all records whose value is 7, and will drop records whose value is null + // NOTE: this is different from how some query languages handle null. + switch (parquetField.primitiveType.getPrimitiveTypeName()) { + case BOOLEAN: + return createIsNullOrIsNotNullPredicate(booleanColumn(columnPath), isNotNull); + case INT32: + return createIsNullOrIsNotNullPredicate(intColumn(columnPath), isNotNull); + case INT64: + return createIsNullOrIsNotNullPredicate(longColumn(columnPath), isNotNull); + case FLOAT: + return createIsNullOrIsNotNullPredicate(floatColumn(columnPath), isNotNull); + case DOUBLE: + return createIsNullOrIsNotNullPredicate(doubleColumn(columnPath), isNotNull); + case BINARY: + return createIsNullOrIsNotNullPredicate(binaryColumn(columnPath), isNotNull); + default: + return visitUnsupported( + deltaPredicate, + "Unsupported column type: " + parquetField.primitiveType); + } + } + + private static , C extends Operators.Column & SupportsEqNotEq> + Optional createIsNullOrIsNotNullPredicate( + C column, + boolean isNotNull) { + return Optional.of(isNotNull ? 
FilterApi.notEq(column, null) : FilterApi.eq(column, null)); + } + private static Optional visitUnsupported( Predicate predicate, String message) { diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetReaderPredicatePushdownSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetReaderPredicatePushdownSuite.scala index 18853fa9dd0..bc55055f210 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetReaderPredicatePushdownSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/parquet/ParquetReaderPredicatePushdownSuite.scala @@ -92,20 +92,70 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite private def generateRowsGroup(rowGroupIdx: Int): Seq[Row] = { def values(rowId: Int): Seq[Any] = { + // One of the columns in each row group is all nulls or all non-nulls depending on + // the [[rowGroupIdx]]. This helps to verify the test results for `is null` and + // `is not null` pushdown Seq( - if (rowId % 72 != 0) rowId.byteValue() else null, - if (rowId % 56 != 0) rowId.shortValue() else null, - if (rowId % 23 != 0) rowId else null, - if (rowId % 25 != 0) (rowId + 1).longValue() else null, - if (rowId % 28 != 0) (rowId + 0.125).floatValue() else null, - if (rowId % 54 != 0) (rowId + 0.000001).doubleValue() else null, - if (rowId % 57 != 0) "%05d".format(rowId) else null, - if (rowId % 57 != 0) "%050d".format(rowId) else null, // truncated stats - if (rowId % 59 != 0) "%06d".format(rowId).getBytes else null, - if (rowId % 59 != 0) "%060d".format(rowId).getBytes else null, // truncated stats - // alternate between true and false for each row group - (rowId / 100) % 2 == 0, - if (rowId % 61 != 0) new Date(rowId * 86400000L /* millis in a day */) else null + // byteCol + if (rowGroupIdx == 0) null /* all nulls */ + else if (rowGroupIdx == 11) rowId.byteValue() /* all non-nulls */ + else (if 
(rowId % 72 != 0) rowId.byteValue() else null), /* mix of nulls and non-nulls */ + + // shortCol + if (rowGroupIdx == 1) null + else if (rowGroupIdx == 10) rowId.shortValue() + else (if (rowId % 56 != 0) rowId.shortValue() else null), + + // intCol + if (rowGroupIdx == 2) null + else if (rowGroupIdx == 9) rowId + else (if (rowId % 23 != 0) rowId else null), + + // longCol + if (rowGroupIdx == 3) null + else if (rowGroupIdx == 8) (rowId + 1).longValue() + else (if (rowId % 25 != 0) (rowId + 1).longValue() else null), + + // floatCol + if (rowGroupIdx == 4) null + else if (rowGroupIdx == 7) (rowId + 0.125).floatValue() + else (if (rowId % 28 != 0) (rowId + 0.125).floatValue() else null), + + // doubleCol + if (rowGroupIdx == 5) null + else if (rowGroupIdx == 6) (rowId + 0.000001).doubleValue() + else (if (rowId % 54 != 0) (rowId + 0.000001).doubleValue() else null), + + // stringCol + if (rowGroupIdx == 6) null + else if (rowGroupIdx == 5) "%05d".format(rowId) + else (if (rowId % 57 != 0) "%05d".format(rowId) else null), + + // truncatedStringCol - stats will be truncated as the value is too long + if (rowGroupIdx == 7) null + else if (rowGroupIdx == 4) "%050d".format(rowId) + else (if (rowId % 57 != 0) "%050d".format(rowId) else null), + + // binaryCol + if (rowGroupIdx == 8) null + else if (rowGroupIdx == 3) "%06d".format(rowId).getBytes + else (if (rowId % 59 != 0) "%06d".format(rowId).getBytes else null), + + // truncatedBinaryCol - stats will be truncated as the value is too long + if (rowGroupIdx == 9) null + else if (rowGroupIdx == 2) "%060d".format(rowId).getBytes + else (if (rowId % 59 != 0) "%060d".format(rowId).getBytes else null), + + // booleanCol + if (rowGroupIdx == 10) null + else if (rowGroupIdx == 1) rowId % 2 == 0 + // alternate between true and false for each row group + else (if (rowId % 29 != 0) rowGroupIdx % 2 == 0 else null), + + // dateCol + if (rowGroupIdx == 11) null + else if (rowGroupIdx == 0) new Date(rowId * 86400000L) + else (if 
(rowId % 61 != 0) new Date(rowId * 86400000L) else null) ) } @@ -162,7 +212,7 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite ( lt(col("floatCol"), ofFloat(1000.0f)), lt(col("nested", "floatCol"), ofFloat(1000.0f)), - Seq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) // expected row groups + Seq(0, 1, 2, 3, 5, 6, 7, 8, 9) // expected row groups - row group 4 has all nulls ), // filter on double type column ( @@ -174,7 +224,9 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite ( eq(col("booleanCol"), ofBoolean(true)), eq(col("nested", "booleanCol"), ofBoolean(true)), - Seq(0, 2, 4, 6, 8, 10, 12, 14, 16, 18) // expected row groups + // expected row groups + // 1 has mix of true/false (included), 10 has all nulls (not included) + Seq(0, 1, 2, 4, 6, 8, 12, 14, 16, 18) ), // filter on date type column ( @@ -200,7 +252,8 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite ( gte(col("truncatedStringCol"), ofString("%050d".format(300))), gte(col("nested", "truncatedStringCol"), ofString("%050d".format(300))), - Seq(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19) // expected row groups + // expected row groups + Seq(3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19) // 7 has all nulls ), // filter on truncated stats binary type column ( @@ -222,6 +275,56 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite } } + // IS NULL and IS NOT NULL tests + Seq( + // (columnName, row groups with all nulls, row groups with all non-nulls) + ("byteCol", Seq(0), Seq(11)), // byte type column + ("shortCol", Seq(1), Seq(10)), // short type column + ("intCol", Seq(2), Seq(9)), // int type column + ("longCol", Seq(3), Seq(8)), // long type column + ("floatCol", Seq(4), Seq(7)), // float type column + ("doubleCol", Seq(5), Seq(6)), // double type column + ("stringCol", Seq(6), Seq(5)), // string type column + ("truncatedStringCol", Seq(7), Seq(4)), // truncatedStringCol type column + ("binaryCol", Seq(8), Seq(3)), // binary type column 
+ ("truncatedBinaryCol", Seq(9), Seq(2)), // truncatedBinaryCol type column + ("booleanCol", Seq(10), Seq(1)), // boolean type column + ("dateCol", Seq(11), Seq(0)) // date type column + ).foreach { + // Test table has 20 row groups, each with 100 rows. + case (colName, allNullsRowGroups, allNonNullsRowGroups) => + // Test predicate on both top-level and nested columns + Seq(col(colName), col("nested", colName)).foreach { column => + val isNullFilter = isNull(column) + test(s"filter pushdown: $isNullFilter") { + val actualData = readUsingKernel(testParquetTable, isNullFilter) + val expOutputRowCount = 100 * (20 - 1) // 100 rows per row group + + // we get everything except the rowgroup that has all non-nulls + val expRowGroups = (0 until 20).filter(!allNonNullsRowGroups.contains(_)) + assert(actualData.size === expOutputRowCount, s"predicate: $isNullFilter") + checkAnswer(actualData, generateExpData(expRowGroups)) + + // not (col is null) should return all row groups except the one with all nulls + assertNot(isNullFilter, (0 until 20).filter(!allNullsRowGroups.contains(_))) + } + + val isNotNullFilter = isNotNull(column) + test(s"filter pushdown: $isNotNullFilter") { + val actualData = readUsingKernel(testParquetTable, isNotNullFilter) + val expOutputRowCount = 100 * (20 - 1) // 100 rows per row group + + // we get everything except the rowgroup that has all nulls + val expRowGroups = (0 until 20).filter(!allNullsRowGroups.contains(_)) + assert(actualData.size === expOutputRowCount, s"predicate: $isNotNullFilter") + checkAnswer(actualData, generateExpData(expRowGroups)) + + // not (col is not null) should return all row groups except the one with all non-nulls + assertNot(isNotNullFilter, (0 until 20).filter(!allNonNullsRowGroups.contains(_))) + } + } + } + test("for a column that doesn't exist in the table") { val testPredicate = predicate("=", col("nonExistentCol"), ofInt(20)) assertConvertedFilterIsEmpty(testPredicate, testParquetTable) @@ -303,11 +406,26 
@@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite checkAnswer(actData, generateExpData(Seq(15))) } - test("not support") { - val predicate = not(eq(col("booleanCol"), ofBoolean(true))) + test("not support on gt") { + val predicate = not(gt(col("intCol"), ofInt(950))) val actData = readUsingKernel(testParquetTable, predicate) - // every odd rowgroup has true for booleanCol - checkAnswer(actData, generateExpData(Seq(1, 3, 5, 7, 9, 11, 13, 15, 17, 19))) + + // rowgroups until 9 could have values <= 950 + // rowgroup 2 has all nulls, so it won't be included in the result + val expRowGroups = Seq(0, 1, 3, 4, 5, 6, 7, 8, 9) + val expOutputRowCount = expRowGroups.length * 100 // 100 rows per row group + assert(actData.size === expOutputRowCount, s"predicate: $predicate") + + checkAnswer(actData, generateExpData(expRowGroups)) + } + + test("not support on equality") { + val predicate = not(eq(col("longCol"), ofLong(768))) + val actData = readUsingKernel(testParquetTable, predicate) + // rowgroup 3 has all nulls, so it will be included in the results as + // Parquet equality filter is not null safe + // every other group has value that is not 768 + checkAnswer(actData, generateExpData(Seq.range(0, 20))) } test("doesn't work on the repeated columns") { @@ -322,4 +440,13 @@ class ParquetReaderPredicatePushdownSuite extends AnyFunSuite checkAnswer(actResult, expResult) } + + /** Test the `not(predicate)` returns expected rowgroups */ + private def assertNot(predicate: Predicate, expRowGroups: Seq[Int]): Unit = { + val notPredicate = not(predicate) + val actualData = readUsingKernel(testParquetTable, notPredicate) + val expOutputRowCount = expRowGroups.length * 100 // 100 rows per row group + assert(actualData.size === expOutputRowCount, s"predicate: $notPredicate") + checkAnswer(actualData, generateExpData(expRowGroups)) + } }