Skip to content

Commit

Permalink
[Kernel] Change comparator expression to lazy evaluation (#2853)
Browse files Browse the repository at this point in the history
## Description
Resolves #2541

## How was this patch tested?
Existing tests
  • Loading branch information
zzl-7 authored May 30, 2024
1 parent 39e91af commit 085f117
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@
import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector;
import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector;
import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException;
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*;
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector;
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.childAt;
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.compare;
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.evalNullability;
import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;

/**
Expand Down Expand Up @@ -421,44 +420,37 @@ ColumnVector visitAlwaysFalse(AlwaysFalse alwaysFalse) {
@Override
ColumnVector visitComparator(Predicate predicate) {
PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(predicate);

int numRows = argResults.rowCount;
boolean[] result = new boolean[numRows];
boolean[] nullability = evalNullability(argResults.leftResult, argResults.rightResult);
int[] compareResult = compare(argResults.leftResult, argResults.rightResult);
switch (predicate.getName()) {
case "=":
for (int rowId = 0; rowId < numRows; rowId++) {
result[rowId] = compareResult[rowId] == 0;
}
break;
return comparatorVector(
argResults.leftResult,
argResults.rightResult,
(compareResult) -> (compareResult == 0));
case ">":
for (int rowId = 0; rowId < numRows; rowId++) {
result[rowId] = compareResult[rowId] > 0;
}
break;
return comparatorVector(
argResults.leftResult,
argResults.rightResult,
(compareResult) -> (compareResult > 0));
case ">=":
for (int rowId = 0; rowId < numRows; rowId++) {
result[rowId] = compareResult[rowId] >= 0;
}
break;
return comparatorVector(
argResults.leftResult,
argResults.rightResult,
(compareResult) -> (compareResult >= 0));
case "<":
for (int rowId = 0; rowId < numRows; rowId++) {
result[rowId] = compareResult[rowId] < 0;
}
break;
return comparatorVector(
argResults.leftResult,
argResults.rightResult,
(compareResult) -> (compareResult < 0));
case "<=":
for (int rowId = 0; rowId < numRows; rowId++) {
result[rowId] = compareResult[rowId] <= 0;
}
break;
return comparatorVector(
argResults.leftResult,
argResults.rightResult,
(compareResult) -> (compareResult <= 0));
default:
// We should never reach this based on the ExpressionVisitor
throw new IllegalStateException(
String.format("%s is not a recognized comparator", predicate.getName()));
}

return new DefaultBooleanVector(numRows, Optional.of(nullability), result);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;

import io.delta.kernel.data.ArrayValue;
Expand All @@ -33,6 +34,20 @@
* Utility methods used by the default expression evaluator.
*/
class DefaultExpressionUtils {

static final Comparator<BigDecimal> BIGDECIMAL_COMPARATOR = Comparator.naturalOrder();
static final Comparator<String> STRING_COMPARATOR = Comparator.naturalOrder();
static final Comparator<byte[]> BINARY_COMPARTOR = (leftOp, rightOp) -> {
int i = 0;
while (i < leftOp.length && i < rightOp.length) {
if (leftOp[i] != rightOp[i]) {
return Byte.compare(leftOp[i], rightOp[i]);
}
i++;
}
return Integer.compare(leftOp.length, rightOp.length);
};

private DefaultExpressionUtils() {}

/**
Expand Down Expand Up @@ -87,138 +102,91 @@ public boolean getBoolean(int rowId) {
}

/**
* Utility method to compare the left and right according to the natural ordering
* and return an integer array where each row contains the comparison result (-1, 0, 1) for
* corresponding rows in the input vectors compared.
* Utility method to create a column vector that lazily evaluate the
* comparator ex. (ie. ==, >=, <=......) for left and right
* column vector according to the natural ordering of numbers
* <p>
* Only primitive data types are supported.
*/
static int[] compare(ColumnVector left, ColumnVector right) {
static ColumnVector comparatorVector(
ColumnVector left,
ColumnVector right,
IntPredicate booleanComparator) {
checkArgument(
left.getSize() == right.getSize(),
"Left and right operand have different vector sizes.");
DataType dataType = left.getDataType();
left.getSize() == right.getSize(),
"Left and right operand have different vector sizes.");

int numRows = left.getSize();
int[] result = new int[numRows];
DataType dataType = left.getDataType();
IntPredicate vectorValueComparator;
if (dataType instanceof BooleanType) {
compareBoolean(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId)));
} else if (dataType instanceof ByteType) {
compareByte(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
Byte.compare(left.getByte(rowId), right.getByte(rowId)));
} else if (dataType instanceof ShortType) {
compareShort(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
Short.compare(left.getShort(rowId), right.getShort(rowId)));
} else if (dataType instanceof IntegerType || dataType instanceof DateType) {
compareInt(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
Integer.compare(left.getInt(rowId), right.getInt(rowId)));
} else if (dataType instanceof LongType ||
dataType instanceof TimestampType ||
dataType instanceof TimestampNTZType) {
compareLong(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
Long.compare(left.getLong(rowId), right.getLong(rowId)));
} else if (dataType instanceof FloatType) {
compareFloat(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
Float.compare(left.getFloat(rowId), right.getFloat(rowId)));
} else if (dataType instanceof DoubleType) {
compareDouble(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
Double.compare(left.getDouble(rowId), right.getDouble(rowId)));
} else if (dataType instanceof DecimalType) {
compareDecimal(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
BIGDECIMAL_COMPARATOR.compare(
left.getDecimal(rowId), right.getDecimal(rowId)));
} else if (dataType instanceof StringType) {
compareString(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
STRING_COMPARATOR.compare(
left.getString(rowId), right.getString(rowId)));
} else if (dataType instanceof BinaryType) {
compareBinary(left, right, result);
vectorValueComparator = rowId -> booleanComparator.test(
BINARY_COMPARTOR.compare(
left.getBinary(rowId), right.getBinary(rowId)));
} else {
throw new UnsupportedOperationException(dataType + " can not be compared.");
}
return result;
}

static void compareBoolean(ColumnVector left, ColumnVector right, int[] result) {
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId));
}
}
}

static void compareByte(ColumnVector left, ColumnVector right, int[] result) {
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = Byte.compare(left.getByte(rowId), right.getByte(rowId));
}
}
}

static void compareShort(ColumnVector left, ColumnVector right, int[] result) {
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = Short.compare(left.getShort(rowId), right.getShort(rowId));
}
}
}

static void compareInt(ColumnVector left, ColumnVector right, int[] result) {
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = Integer.compare(left.getInt(rowId), right.getInt(rowId));
}
}
}

static void compareLong(ColumnVector left, ColumnVector right, int[] result) {
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = Long.compare(left.getLong(rowId), right.getLong(rowId));
}
}
}
return new ColumnVector() {

static void compareFloat(ColumnVector left, ColumnVector right, int[] result) {
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = Float.compare(left.getFloat(rowId), right.getFloat(rowId));
@Override
public DataType getDataType() {
return BooleanType.BOOLEAN;
}
}
}

static void compareDouble(ColumnVector left, ColumnVector right, int[] result) {
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = Double.compare(left.getDouble(rowId), right.getDouble(rowId));
@Override
public void close() {
Utils.closeCloseables(left, right);
}
}
}

static void compareString(ColumnVector left, ColumnVector right, int[] result) {
Comparator<String> comparator = Comparator.naturalOrder();
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = comparator.compare(left.getString(rowId), right.getString(rowId));
@Override
public int getSize() {
return left.getSize();
}
}
}

static void compareDecimal(ColumnVector left, ColumnVector right, int[] result) {
Comparator<BigDecimal> comparator = Comparator.naturalOrder();
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = comparator.compare(left.getDecimal(rowId), right.getDecimal(rowId));
@Override
public boolean isNullAt(int rowId) {
return left.isNullAt(rowId) || right.isNullAt(rowId);
}
}
}

static void compareBinary(ColumnVector left, ColumnVector right, int[] result) {
Comparator<byte[]> comparator = (leftOp, rightOp) -> {
int i = 0;
while (i < leftOp.length && i < rightOp.length) {
if (leftOp[i] != rightOp[i]) {
return Byte.compare(leftOp[i], rightOp[i]);
@Override
public boolean getBoolean(int rowId) {
if (isNullAt(rowId)) {
return false;
}
i++;
return vectorValueComparator.test(rowId);
}
return Integer.compare(leftOp.length, rightOp.length);
};
for (int rowId = 0; rowId < left.getSize(); rowId++) {
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
result[rowId] = comparator.compare(left.getBinary(rowId), right.getBinary(rowId));
}
}
}

static Expression childAt(Expression expression, int index) {
Expand Down

0 comments on commit 085f117

Please # to comment.