Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

feat: Add SingleOrList support to the Isthmus converter #159

Merged
merged 3 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableExpression.Cast;
import io.substrait.expression.ImmutableExpression.SingleOrList;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.plan.ImmutablePlan;
Expand Down Expand Up @@ -237,6 +238,10 @@ public Expression.BoolLiteral bool(boolean v) {
return Expression.BoolLiteral.builder().value(v).build();
}

public Expression.I32Literal i32(int v) {
return Expression.I32Literal.builder().value(v).build();
}

public FieldReference fieldReference(Rel input, int index) {
return ImmutableFieldReference.newInputRelReference(index, input);
}
Expand Down Expand Up @@ -266,6 +271,10 @@ public List<Expression.SortField> sortFields(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

public Expression singleOrList(Expression condition, Expression... options) {
return SingleOrList.builder().condition(condition).addOptions(options).build();
}

// Aggregate Functions

public AggregateFunctionInvocation aggregateFn(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.isthmus.expression;

import io.substrait.expression.*;
import io.substrait.expression.Expression.SingleOrList;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.StringTypeVisitor;
Expand Down Expand Up @@ -140,6 +141,13 @@ public RexNode visit(Expression.TimeLiteral expr) throws RuntimeException {
return rexBuilder.makeLiteral(timeString, typeConverter.toCalcite(typeFactory, expr.getType()));
}

@Override
public RexNode visit(SingleOrList expr) throws RuntimeException {
var lhs = expr.condition().accept(this);
return rexBuilder.makeIn(
lhs, expr.options().stream().map(e -> e.accept(this)).collect(Collectors.toList()));
}

@Override
public RexNode visit(Expression.DateLiteral expr) throws RuntimeException {
return rexBuilder.makeLiteral(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ public void filter() throws IOException, SqlParseException {
assertProtoPlanRoundrip("select * from lineitem WHERE L_ORDERKEY > 10");
}

@Test
public void in() throws IOException, SqlParseException {
assertProtoPlanRoundrip("select * from lineitem WHERE L_ORDERKEY IN (10, 20)");
}

@Test
public void joinWithMultiDDLInOneString() throws IOException, SqlParseException {
assertProtoPlanRoundrip(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,17 @@ public void emit() {
var relNode = converter.convert(root.getInput());
assertRowMatch(relNode.getRowType(), R.I32, N.STRING);
}

@Test
public void singleOrList() {
Plan.Root root =
b.root(
b.filter(
input -> b.singleOrList(b.fieldReference(input, 0), b.i32(5), b.i32(10)),
commonTable));
var relNode = converter.convert(root.getInput());
assertRowMatch(relNode.getRowType(), commonTableType);
}
}

@Nested
Expand Down