diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index d3d3511fb..a1e5e3fba 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -53,9 +53,17 @@ public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) { } // Relations + public Aggregate.Measure measure(AggregateFunctionInvocation aggFn) { + return Aggregate.Measure.builder().function(aggFn).build(); + } + + public Aggregate.Measure measure(AggregateFunctionInvocation aggFn, Expression preMeasureFilter) { + return Aggregate.Measure.builder().function(aggFn).preMeasureFilter(preMeasureFilter).build(); + } + public Aggregate aggregate( Function groupingFn, - Function> measuresFn, + Function> measuresFn, Rel input) { Function> groupingsFn = groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList())); @@ -64,7 +72,7 @@ public Aggregate aggregate( public Aggregate aggregate( Function groupingFn, - Function> measuresFn, + Function> measuresFn, Rel.Remap remap, Rel input) { Function> groupingsFn = @@ -74,14 +82,11 @@ public Aggregate aggregate( private Aggregate aggregate( Function> groupingsFn, - Function> measuresFn, + Function> measuresFn, Optional remap, Rel input) { var groupings = groupingsFn.apply(input); - var measures = - measuresFn.apply(input).stream() - .map(m -> Aggregate.Measure.builder().function(m).build()) - .collect(java.util.stream.Collectors.toList()); + var measures = measuresFn.apply(input); return Aggregate.builder() .groupings(groupings) .measures(measures) @@ -389,6 +394,11 @@ public List sortFields(Rel input, int... 
indexes) { .collect(java.util.stream.Collectors.toList()); } + public Expression.SortField sortField( + Expression expression, Expression.SortDirection sortDirection) { + return Expression.SortField.builder().expr(expression).direction(sortDirection).build(); + } + public SwitchClause switchClause(Expression.Literal condition, Expression then) { return SwitchClause.builder().condition(condition).then(then).build(); } @@ -422,76 +432,150 @@ public Aggregate.Grouping grouping(Rel input, int... indexes) { return Aggregate.Grouping.builder().addAllExpressions(columns).build(); } - public AggregateFunctionInvocation count(Rel input, int field) { + public Aggregate.Grouping grouping(Expression... expressions) { + return Aggregate.Grouping.builder().addExpressions(expressions).build(); + } + + public Aggregate.Measure count(Rel input, int field) { var declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any")); - return AggregateFunctionInvocation.builder() - .arguments(fieldReferences(input, field)) - .outputType(R.I64) - .declaration(declaration) - .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) - .invocation(Expression.AggregationInvocation.ALL) - .build(); + return measure( + AggregateFunctionInvocation.builder() + .arguments(fieldReferences(input, field)) + .outputType(R.I64) + .declaration(declaration) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()); + } + + public Aggregate.Measure min(Rel input, int field) { + return min(fieldReference(input, field)); } - public AggregateFunctionInvocation min(Rel input, int field) { - Type inputType = input.getRecordType().fields().get(field); - // min output is always nullable + public Aggregate.Measure min(Expression expr) { return singleArgumentArithmeticAggregate( - input, field, "min", TypeCreator.asNullable(inputType)); + expr, + "min", + 
// min output is always nullable + TypeCreator.asNullable(expr.getType())); } - public AggregateFunctionInvocation max(Rel input, int field) { - Type inputType = input.getRecordType().fields().get(field); - // max output is always nullable + public Aggregate.Measure max(Rel input, int field) { + return max(fieldReference(input, field)); + } + + public Aggregate.Measure max(Expression expr) { return singleArgumentArithmeticAggregate( - input, field, "max", TypeCreator.asNullable(inputType)); + expr, + "max", + // max output is always nullable + TypeCreator.asNullable(expr.getType())); } - public AggregateFunctionInvocation avg(Rel input, int field) { - Type inputType = input.getRecordType().fields().get(field); - // avg output is always nullable + public Aggregate.Measure avg(Rel input, int field) { + return avg(fieldReference(input, field)); + } + + public Aggregate.Measure avg(Expression expr) { return singleArgumentArithmeticAggregate( - input, field, "avg", TypeCreator.asNullable(inputType)); + expr, + "avg", + // avg output is always nullable + TypeCreator.asNullable(expr.getType())); + } + + public Aggregate.Measure sum(Rel input, int field) { + return sum(fieldReference(input, field)); } - public AggregateFunctionInvocation sum(Rel input, int field) { - Type inputType = input.getRecordType().fields().get(field); - // sum output is always nullable + public Aggregate.Measure sum(Expression expr) { return singleArgumentArithmeticAggregate( - input, field, "sum", TypeCreator.asNullable(inputType)); + expr, + "sum", + // sum output is always nullable + TypeCreator.asNullable(expr.getType())); } - public AggregateFunctionInvocation sum0(Rel input, int field) { - // sum0 output is always NOT NULL I64 - return singleArgumentArithmeticAggregate(input, field, "sum0", R.I64); + public Aggregate.Measure sum0(Rel input, int field) { + return sum0(fieldReference(input, field)); } - private AggregateFunctionInvocation singleArgumentArithmeticAggregate( - Rel input, int 
field, String functionName, Type outputType) { - Type inputType = input.getRecordType().fields().get(field); - String typeString = inputType.accept(ToTypeString.INSTANCE); + public Aggregate.Measure sum0(Expression expr) { + return singleArgumentArithmeticAggregate( + expr, + "sum0", + // sum0 output is always NOT NULL I64 + R.I64); + } + + private Aggregate.Measure singleArgumentArithmeticAggregate( + Expression expr, String functionName, Type outputType) { + String typeString = ToTypeString.apply(expr.getType()); var declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, String.format("%s:%s", functionName, typeString))); - return AggregateFunctionInvocation.builder() - .arguments(fieldReferences(input, field)) - .outputType(outputType) - .declaration(declaration) - // INITIAL_TO_RESULT is the most restrictive aggregation phase type, - // as it does not allow decomposition. Use it as the default for now. - // TODO: set this per function - .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) - .invocation(Expression.AggregationInvocation.ALL) - .build(); + return measure( + AggregateFunctionInvocation.builder() + .arguments(Arrays.asList(expr)) + .outputType(outputType) + .declaration(declaration) + // INITIAL_TO_RESULT is the most restrictive aggregation phase type, + // as it does not allow decomposition. Use it as the default for now. 
+ // TODO: set this per function + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()); } // Scalar Functions + public Expression.ScalarFunctionInvocation negate(Expression expr) { + // output type of negate is the same as the input type + var outputType = expr.getType(); + return scalarFn( + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, + String.format("negate:%s", ToTypeString.apply(outputType)), + outputType, + expr); + } + + public Expression.ScalarFunctionInvocation add(Expression left, Expression right) { + return arithmeticFunction("add", left, right); + } + + public Expression.ScalarFunctionInvocation subtract(Expression left, Expression right) { + return arithmeticFunction("subtract", left, right); + } + + public Expression.ScalarFunctionInvocation multiply(Expression left, Expression right) { + return arithmeticFunction("multiply", left, right); + } + + public Expression.ScalarFunctionInvocation divide(Expression left, Expression right) { + return arithmeticFunction("divide", left, right); + } + + private Expression.ScalarFunctionInvocation arithmeticFunction( + String fname, Expression left, Expression right) { + var leftTypeStr = ToTypeString.apply(left.getType()); + var rightTypeStr = ToTypeString.apply(right.getType()); + var key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr); + + var isOutputNullable = left.getType().nullable() || right.getType().nullable(); + var outputType = left.getType(); + outputType = + isOutputNullable + ? 
TypeCreator.asNullable(outputType) + : TypeCreator.asNotNullable(outputType); + + return scalarFn(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, key, outputType, left, right); + } + public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) { return scalarFn( DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right); diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java index 11e7a4eb7..1b7138317 100644 --- a/core/src/main/java/io/substrait/function/ToTypeString.java +++ b/core/src/main/java/io/substrait/function/ToTypeString.java @@ -5,7 +5,11 @@ public class ToTypeString extends ParameterizedTypeVisitor.ParameterizedTypeThrowsVisitor { - public static ToTypeString INSTANCE = new ToTypeString(); + public static final ToTypeString INSTANCE = new ToTypeString(); + + public static String apply(Type type) { + return type.accept(INSTANCE); + } private ToTypeString() { super("Only type literals and parameterized types can be used in functions."); diff --git a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java new file mode 100644 index 000000000..3b22a2e87 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java @@ -0,0 +1,227 @@ +package io.substrait.isthmus; + +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.expression.FunctionArg; +import io.substrait.expression.ImmutableExpression; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.relation.Aggregate; +import io.substrait.relation.Project; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Not all Substrait {@link 
Aggregate} rels are convertible to {@link + * org.apache.calcite.rel.core.Aggregate} rels + * + *

The code in this class can: + * + *

    + *
  • Check for these cases + *
  • Rewrite the Substrait {@link Aggregate} such that it can be converted to Calcite + *
+ */ +public class PreCalciteAggregateValidator { + + /** + * Checks that the given {@link Aggregate} is valid for use in Calcite + * + * @param aggregate the Substrait aggregate whose measures and groupings are checked + * @return true if every measure and every grouping of the {@code aggregate} passes validation, false otherwise + */ + public static boolean isValidCalciteAggregate(Aggregate aggregate) { + return aggregate.getMeasures().stream() + .allMatch(PreCalciteAggregateValidator::isValidCalciteMeasure) + && aggregate.getGroupings().stream() + .allMatch(PreCalciteAggregateValidator::isValidCalciteGrouping); + } + + /** + * Checks that all expressions present in the given {@link Aggregate.Measure} are {@link + * FieldReference}s, as Calcite expects all expressions in {@link + * org.apache.calcite.rel.core.Aggregate}s to be field references. + * + * @return true if the {@code measure} can be converted to a Calcite equivalent without changes, + * false otherwise. + */ + private static boolean isValidCalciteMeasure(Aggregate.Measure measure) { + return + // all function arguments to measures must be field references + measure.getFunction().arguments().stream().allMatch(farg -> isSimpleFieldReference(farg)) + && + // all sort fields must be field references + measure.getFunction().sort().stream().allMatch(sf -> isSimpleFieldReference(sf.expr())) + && + // pre-measure filter must be a field reference + measure.getPreMeasureFilter().map(f -> isSimpleFieldReference(f)).orElse(true); + } + + /** + * Checks that all expressions present in the given {@link Aggregate.Grouping} are {@link + * FieldReference}s, as Calcite expects all expressions in {@link + * org.apache.calcite.rel.core.Aggregate}s to be field references. + * + *

Additionally, checks that all grouping fields are specified in ascending order. + * + * @return true if the {@code grouping} can be converted to a Calcite equivalent without changes, + * false otherwise. + */ + private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) { + if (!grouping.getExpressions().stream().allMatch(e -> isSimpleFieldReference(e))) { + // all grouping expressions must be field references + return false; + } + + // Calcite stores grouping fields in an ImmutableBitSet and does not track the order of the + // grouping fields. The output record shape that Calcite generates ALWAYS has the groupings in + // ascending field order. This causes issues with Substrait in cases where the grouping fields + // in Substrait are not defined in ascending order. + + // For example, if a grouping is defined as (0, 2, 1) in Substrait, Calcite will output it as + // (0, 1, 2), which means that the Calcite output will no longer line up with the expectations + // of the Substrait plan. 
+ List groupingFields = + grouping.getExpressions().stream() + // isSimpleFieldReference above guarantees that the expr is a FieldReference + .map(expr -> getFieldRefOffset((FieldReference) expr)) + .collect(Collectors.toList()); + + return isOrdered(groupingFields); + } + + private static boolean isSimpleFieldReference(FunctionArg e) { + return e instanceof FieldReference fr + && fr.segments().size() == 1 + && fr.segments().get(0) instanceof FieldReference.StructField; + } + + private static int getFieldRefOffset(FieldReference fr) { + return ((FieldReference.StructField) fr.segments().get(0)).offset(); + } + + private static boolean isOrdered(List list) { + for (int i = 1; i < list.size(); i++) { + if (list.get(i - 1) > list.get(i)) { + return false; + } + } + return true; + } + + public static class PreCalciteAggregateTransformer { + + // New expressions to include in the project before the aggregate + private final List newExpressions; + + // Tracks the offset of the next expression added + private int expressionOffset; + + private PreCalciteAggregateTransformer(Aggregate aggregate) { + this.newExpressions = new ArrayList<>(); + // The Substrait project output includes all input fields, followed by expressions + this.expressionOffset = aggregate.getInput().getRecordType().fields().size(); + } + + /** + * Transforms an {@link Aggregate} that cannot be handled by Calcite into an equivalent that can + * be handled by: + * + *

    + *
  • Moving all non-field references into a project before the aggregation + *
  • Adding all groupings to this project so that they are referenced in "order" + *
+ */ + public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) { + var at = new PreCalciteAggregateTransformer(aggregate); + + List newMeasures = + aggregate.getMeasures().stream().map(at::updateMeasure).collect(Collectors.toList()); + List newGroupings = + aggregate.getGroupings().stream().map(at::updateGrouping).collect(Collectors.toList()); + + Project preAggregateProject = + Project.builder().input(aggregate.getInput()).expressions(at.newExpressions).build(); + + return Aggregate.builder() + .from(aggregate) + .input(preAggregateProject) + .measures(newMeasures) + .groupings(newGroupings) + .build(); + } + + private Aggregate.Measure updateMeasure(Aggregate.Measure measure) { + AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction(); + + List newFunctionArgs = + oldAggregateFunctionInvocation.arguments().stream() + .map(this::projectOutNonFieldReference) + .collect(Collectors.toList()); + + List newSortFields = + oldAggregateFunctionInvocation.sort().stream() + .map( + sf -> + Expression.SortField.builder() + .from(sf) + .expr(projectOutNonFieldReference(sf.expr())) + .build()) + .collect(Collectors.toList()); + + Optional newPreMeasureFilter = + measure.getPreMeasureFilter().map(this::projectOutNonFieldReference); + + AggregateFunctionInvocation newAggregateFunctionInvocation = + AggregateFunctionInvocation.builder() + .from(oldAggregateFunctionInvocation) + .arguments(newFunctionArgs) + .sort(newSortFields) + .build(); + + return Aggregate.Measure.builder() + .function(newAggregateFunctionInvocation) + .preMeasureFilter(newPreMeasureFilter) + .build(); + } + + private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) { + // project out all groupings unconditionally, even field references + // this ensures that out of order groupings are re-projected into in order groupings + List newGroupingExpressions = + grouping.getExpressions().stream().map(this::projectOut).collect(Collectors.toList()); + 
return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build(); + } + + private Expression projectOutNonFieldReference(FunctionArg farg) { + if ((farg instanceof Expression e)) { + return projectOutNonFieldReference(e); + } else { + throw new IllegalArgumentException("cannot handle non-expression argument for aggregate"); + } + } + + private Expression projectOutNonFieldReference(Expression expr) { + if (isSimpleFieldReference(expr)) { + return expr; + } + return projectOut(expr); + } + + /** + * Adds a new expression to the project at {@link + * PreCalciteAggregateTransformer#expressionOffset} and returns a field reference to the new + * expression + */ + private Expression projectOut(Expression expr) { + newExpressions.add(expr); + return ImmutableFieldReference.builder() + // create a field reference to the new expression, then update the expression offset + .addSegments(FieldReference.StructField.of(expressionOffset++)) + .type(expr.getType()) + .build(); + } + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index fb31f8d22..8ff98e6f6 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -31,6 +31,7 @@ import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitDef; import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; @@ -227,6 +228,12 @@ public RelNode visit(Set set) throws RuntimeException { @Override public RelNode visit(Aggregate aggregate) throws RuntimeException { + if (!PreCalciteAggregateValidator.isValidCalciteAggregate(aggregate)) { + aggregate = + PreCalciteAggregateValidator.PreCalciteAggregateTransformer + 
.transformToValidCalciteAggregate(aggregate); + } + RelNode child = aggregate.getInput().accept(this); var groupExprLists = aggregate.getGroupings().stream() @@ -268,8 +275,8 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { } List argIndex = new ArrayList<>(); for (RexNode arg : arguments) { - // TODO: rewrite compound expression into project Rel - checkRexInputRefOnly(arg, "argument", measure.getFunction().declaration().name()); + // arguments are guaranteed to be RexInputRef because of the prior call to + // transformToValidCalciteAggregate argIndex.add(((RexInputRef) arg).getIndex()); } @@ -292,12 +299,18 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { int filterArg = -1; if (measure.getPreMeasureFilter().isPresent()) { RexNode filter = measure.getPreMeasureFilter().get().accept(expressionRexConverter); - // TODO: rewrite compound expression into project Rel - // Calcite's AggregateCall only allow agg filter to be a direct filter from input - checkRexInputRefOnly(filter, "filter", measure.getFunction().declaration().name()); filterArg = ((RexInputRef) filter).getIndex(); } + RelCollation relCollation = RelCollations.EMPTY; + if (!measure.getFunction().sort().isEmpty()) { + relCollation = + RelCollations.of( + measure.getFunction().sort().stream() + .map(sortField -> toRelFieldCollation(sortField)) + .collect(Collectors.toList())); + } + return AggregateCall.create( aggFunction, distinct, @@ -306,7 +319,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { argIndex, filterArg, null, - RelCollations.EMPTY, + relCollation, returnType, null); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 5dd07e7a3..ce8612b9b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -271,6 +271,8 @@ Aggregate.Measure 
fromAggCall(RelNode input, Type.Struct inputType, AggregateCal if (call.filterArg != -1) { builder.preMeasureFilter(FieldReference.newRootStructReference(call.filterArg, inputType)); } + // TODO: handle the collation on the AggregateCall + // https://github.com/substrait-io/substrait-java/issues/215 return builder.build(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java index 657b3ee24..62357aa3d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java @@ -2,7 +2,7 @@ import com.google.common.collect.Streams; import io.substrait.dsl.SubstraitBuilder; -import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.relation.Aggregate; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; import io.substrait.type.Type; @@ -40,7 +40,7 @@ public class AggregationFunctionsTest extends PlanTestBase { private NamedScan numericTypesTable = b.namedScan(List.of("example"), columnNames, tableTypes); // Create the given function call on the given field of the input - private AggregateFunctionInvocation functionPicker(Rel input, int field, String fname) { + private Aggregate.Measure functionPicker(Rel input, int field, String fname) { return switch (fname) { case "min" -> b.min(input, field); case "max" -> b.max(input, field); @@ -53,7 +53,7 @@ private AggregateFunctionInvocation functionPicker(Rel input, int field, String } // Create one function call per numeric type column - private List functions(Rel input, String fname) { + private List functions(Rel input, String fname) { // first column is for grouping, skip it return IntStream.range(1, tableTypes.size()) .boxed() diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java 
new file mode 100644 index 000000000..efa3d00e3 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java @@ -0,0 +1,193 @@ +package io.substrait.isthmus; + +import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.Expression; +import io.substrait.relation.Aggregate; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class ComplexAggregateTest extends PlanTestBase { + + final TypeCreator R = TypeCreator.of(false); + SubstraitBuilder b = new SubstraitBuilder(extensions); + + Aggregate.Measure withPreMeasureFilter(Aggregate.Measure measure, Expression preMeasureFilter) { + return Aggregate.Measure.builder().from(measure).preMeasureFilter(preMeasureFilter).build(); + } + + Aggregate.Measure withSort(Aggregate.Measure measure, List sortFields) { + var afi = + AggregateFunctionInvocation.builder().from(measure.getFunction()).sort(sortFields).build(); + return Aggregate.Measure.builder().from(measure).function(afi).build(); + } + + /** + * Check that: + * + *
    + *
  1. The given {@code pojo} is transformed as expected by {@link + * PreCalciteAggregateValidator.PreCalciteAggregateTransformer#transformToValidCalciteAggregate} + *
  2. The original {@code pojo} can be converted to Calcite without issues + *
+ * + * @param pojo a pojo that requires transformation for use in Calcite + * @param expectedTransform the expected transformation output + */ + protected void validateAggregateTransformation(Aggregate pojo, Rel expectedTransform) { + var converterPojo = + PreCalciteAggregateValidator.PreCalciteAggregateTransformer + .transformToValidCalciteAggregate(pojo); + assertEquals(expectedTransform, converterPojo); + + // Substrait POJO -> Calcite + new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo); + } + + private List columnTypes = List.of(R.I32, R.I32, R.I32, R.I32); + private List columnNames = List.of("a", "b", "c", "d"); + private NamedScan table = b.namedScan(List.of("example"), columnNames, columnTypes); + + private Aggregate.Grouping emptyGrouping = Aggregate.Grouping.builder().build(); + + @Test + void handleComplexMeasureArgument() { + // SELECT sum(c + 7) FROM example + var rel = + b.aggregate( + input -> emptyGrouping, + input -> List.of(b.sum(b.add(b.fieldReference(input, 2), b.i32(7)))), + table); + + var expectedFinal = + b.aggregate( + input -> emptyGrouping, + // sum call references input field + input -> List.of(b.sum(input, 4)), + b.project( + // add call is moved to child project + input -> List.of(b.add(b.fieldReference(input, 2), b.i32(7))), + table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void handleComplexPreMeasureFilter() { + // SELECT sum(a) FILTER (b = 42) FROM example + var rel = + b.aggregate( + input -> emptyGrouping, + input -> + List.of( + withPreMeasureFilter( + b.sum(input, 0), b.equal(b.fieldReference(input, 1), b.i32(42)))), + table); + + var expectedFinal = + b.aggregate( + input -> emptyGrouping, + input -> List.of(withPreMeasureFilter(b.sum(input, 0), b.fieldReference(input, 4))), + b.project(input -> List.of(b.equal(b.fieldReference(input, 1), b.i32(42))), table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void handleComplexSortingArguments() { 
+ // SELECT sum(d ORDER BY -b ASC) FROM example + var rel = + b.aggregate( + input -> emptyGrouping, + input -> + List.of( + withSort( + b.sum(input, 3), + List.of( + b.sortField( + b.negate(b.fieldReference(input, 1)), + Expression.SortDirection.ASC_NULLS_FIRST)))), + table); + + var expectedFinal = + b.aggregate( + input -> emptyGrouping, + input -> + List.of( + withSort( + b.sum(input, 3), + List.of( + b.sortField( + b.fieldReference(input, 4), + Expression.SortDirection.ASC_NULLS_FIRST)))), + b.project( + // negate call is moved to child project + input -> List.of(b.negate(b.fieldReference(input, 1))), + table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void handleComplexGroupingArgument() { + var rel = + b.aggregate( + input -> + b.grouping( + b.fieldReference(input, 2), b.add(b.fieldReference(input, 1), b.i32(42))), + input -> List.of(), + table); + + var expectedFinal = + b.aggregate( + // grouping exprs are now field references to input + input -> b.grouping(input, 4, 5), + input -> List.of(), + b.project( + input -> + List.of( + b.fieldReference(input, 2), b.add(b.fieldReference(input, 1), b.i32(42))), + table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void handleOutOfOrderGroupingArguments() { + var rel = b.aggregate(input -> b.grouping(input, 1, 0, 2), input -> List.of(), table); + + var expectedFinal = + b.aggregate( + // grouping exprs are now field references to input + input -> b.grouping(input, 4, 5, 6), + input -> List.of(), + b.project( + // ALL grouping exprs are added to the child projects (including field references) + input -> + List.of( + b.fieldReference(input, 1), + b.fieldReference(input, 0), + b.fieldReference(input, 2)), + table)); + + validateAggregateTransformation(rel, expectedFinal); + } + + @Test + void outOfOrderGroupingKeysHaveCorrectCalciteType() { + Rel rel = + b.aggregate( + input -> b.grouping(input, 2, 0), + input -> List.of(), + b.namedScan(List.of("foo"), 
List.of("a", "b", "c"), List.of(R.I64, R.I64, R.STRING))); + var relNode = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(rel); + assertRowMatch(relNode.getRowType(), R.STRING, R.I64); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index d7e18bfb2..6c18ded92 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -201,8 +201,9 @@ void customAggregateFunctionRoundtrip() { input -> b.grouping(input, 0), input -> List.of( - b.aggregateFn( - NAMESPACE, "custom_aggregate:i64", R.I64, b.fieldReference(input, 0))), + b.measure( + b.aggregateFn( + NAMESPACE, "custom_aggregate:i64", R.I64, b.fieldReference(input, 0)))), b.namedScan(List.of("example"), List.of("a"), List.of(R.I64))); RelNode calciteRel = substraitToCalcite.convert(rel); diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index c55386df3..c0b725645 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -2,6 +2,7 @@ import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import com.google.common.annotations.Beta; @@ -15,12 +16,14 @@ import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; +import io.substrait.type.Type; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; +import 
org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.SqlKind; @@ -203,4 +206,15 @@ protected void assertFullRoundTrip(Rel pojo1) { // Verify that POJOs are the same assertEquals(pojo1, pojo3); } + + protected void assertRowMatch(RelDataType actual, Type... expected) { + assertRowMatch(actual, Arrays.asList(expected)); + } + + protected void assertRowMatch(RelDataType actual, List expected) { + Type type = TypeConverter.DEFAULT.toSubstrait(actual); + assertInstanceOf(Type.Struct.class, type); + Type.Struct struct = (Type.Struct) type; + assertEquals(expected, struct.fields()); + } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java index 1c90c0bea..37753565f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java @@ -1,8 +1,5 @@ package io.substrait.isthmus; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; - import io.substrait.dsl.SubstraitBuilder; import io.substrait.plan.Plan; import io.substrait.relation.Join.JoinType; @@ -10,11 +7,9 @@ import io.substrait.relation.Set.SetOp; import io.substrait.type.Type; import io.substrait.type.TypeCreator; -import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; -import org.apache.calcite.rel.type.RelDataType; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -35,17 +30,6 @@ public class SubstraitRelNodeConverterTest extends PlanTestBase { final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); - void assertRowMatch(RelDataType actual, Type... 
expected) { - assertRowMatch(actual, Arrays.asList(expected)); - } - - void assertRowMatch(RelDataType actual, List expected) { - Type type = TypeConverter.DEFAULT.toSubstrait(actual); - assertInstanceOf(Type.Struct.class, type); - Type.Struct struct = (Type.Struct) type; - assertEquals(expected, struct.fields()); - } - @Nested class Aggregate { @Test