From d7bf39d234657b22e20bba1dfd1e508f7ff3b466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Ili=C4=87?= Date: Wed, 28 Aug 2024 17:44:03 +0200 Subject: [PATCH] [3.2][Kernel] Fix binary comparator to use the unsigned comparison (#3617) Fixed binary comparator. Previously, bytes were compared as signed, which was incorrect. Tests added to `DefaultExpressionEvaluatorSuite.scala` --- .../expressions/DefaultExpressionUtils.java | 24 +++++----- .../DefaultExpressionEvaluatorSuite.scala | 48 +++++++++++++++++++ 2 files changed, 61 insertions(+), 11 deletions(-) diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java index 26bc0d66d20..a5e1f8bdaab 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java @@ -49,6 +49,17 @@ private DefaultExpressionUtils() {} return Integer.compare(leftBytes.length, rightBytes.length); }; + static final Comparator BINARY_COMPARTOR = (leftOp, rightOp) -> { + int i = 0; + while (i < leftOp.length && i < rightOp.length) { + if (leftOp[i] != rightOp[i]) { + return Byte.toUnsignedInt(leftOp[i]) - Byte.toUnsignedInt(rightOp[i]); + } + i++; + } + return Integer.compare(leftOp.length, rightOp.length); + }; + /** * Utility method that calculates the nullability result from given two vectors. Result is * null if at least one side is a null. @@ -218,19 +229,10 @@ static void compareDecimal(ColumnVector left, ColumnVector right, int[] result) } static void compareBinary(ColumnVector left, ColumnVector right, int[] result) { - Comparator comparator = (leftOp, rightOp) -> { - int i = 0; - while (i < leftOp.length && i < rightOp.length) { - if (leftOp[i] != rightOp[i]) { - return Byte.compare(leftOp[i], rightOp[i]); - } - i++; - } - return Integer.compare(leftOp.length, rightOp.length); - }; for (int rowId = 0; rowId < left.getSize(); rowId++) { if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = comparator.compare(left.getBinary(rowId), right.getBinary(rowId)); + result[rowId] = + BINARY_COMPARTOR.compare(left.getBinary(rowId), right.getBinary(rowId)); } } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala index 8b9172b87f7..20cf23773c8 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -401,6 +401,54 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa ofBinary("apples".getBytes()), ofNull(BinaryType.BINARY) ), + ( + ofBinary(Array[Byte]()), + ofBinary(Array[Byte](5.toByte)), + ofBinary(Array[Byte]()), + ofNull(BinaryType.BINARY) + ), + ( + ofBinary(Array[Byte](0.toByte)), // 00000000 + ofBinary(Array[Byte](-1.toByte)), // 11111111 + ofBinary(Array[Byte](0.toByte)), + ofNull(BinaryType.BINARY) + ), + ( + ofBinary(Array[Byte](127.toByte)), // 01111111 + ofBinary(Array[Byte](-1.toByte)), // 11111111 + ofBinary(Array[Byte](127.toByte)), + ofNull(BinaryType.BINARY) + ), + ( + ofBinary(Array[Byte](5.toByte, 10.toByte)), + ofBinary(Array[Byte](6.toByte)), + ofBinary(Array[Byte](5.toByte, 10.toByte)), + ofNull(BinaryType.BINARY) + ), + ( + ofBinary(Array[Byte](5.toByte, 10.toByte)), + ofBinary(Array[Byte](5.toByte, 100.toByte)), + ofBinary(Array[Byte](5.toByte, 10.toByte)), + ofNull(BinaryType.BINARY) + ), + ( + ofBinary(Array[Byte](5.toByte, 10.toByte, 5.toByte)), // 00000101 00001010 00000101 + ofBinary(Array[Byte](5.toByte, -3.toByte)), // 00000101 11111101 + ofBinary(Array[Byte](5.toByte, 10.toByte, 5.toByte)), + ofNull(BinaryType.BINARY) + ), + ( + ofBinary(Array[Byte](5.toByte, -25.toByte, 5.toByte)), // 00000101 11100111 00000101 + ofBinary(Array[Byte](5.toByte, -9.toByte)), // 00000101 11110111 + ofBinary(Array[Byte](5.toByte, -25.toByte, 5.toByte)), + ofNull(BinaryType.BINARY) + ), + ( + ofBinary(Array[Byte](5.toByte, 10.toByte)), + ofBinary(Array[Byte](5.toByte, 10.toByte, 0.toByte)), + ofBinary(Array[Byte](5.toByte, 10.toByte)), + ofNull(BinaryType.BINARY) + ), ( ofDecimal(BigDecimalJ.valueOf(1.12), 7, 3), ofDecimal(BigDecimalJ.valueOf(5233.232), 7, 3),