Skip to content

Commit

Permalink
feat: improved ReadRel handling (substrait-io#194)
Browse files Browse the repository at this point in the history
* feat: builder util for equal fn
* test: additional ReadRel coverage
* fix: rel to proto NamedScan conversion skipped filter
* fix: rel to proto VirtualScan conversion skipped filter
* fix: proto expression converter should use table schema

Using EMPTY_TYPE caused failures when converting filters that contain
field references on VirtualScans, because the field references were
treated as out-of-bounds.

* fix: allow proto to rel conversion of empty VirtualScan
  • Loading branch information
vbarua authored Nov 2, 2023
1 parent 4769d2f commit 7d478f0
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 5 deletions.
7 changes: 6 additions & 1 deletion core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ public class SubstraitBuilder {
static final TypeCreator R = TypeCreator.of(false);
static final TypeCreator N = TypeCreator.of(true);

private static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
private static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
private static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
private static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";

private final SimpleExtension.ExtensionCollection extensions;

Expand Down Expand Up @@ -396,6 +397,10 @@ private AggregateFunctionInvocation singleArgumentArithmeticAggregate(

// Scalar Functions

/**
 * Creates a scalar invocation of the {@code equal:any_any} function from the comparison
 * extension ({@code FUNCTIONS_COMPARISON}).
 *
 * <p>The output type is always {@code R.BOOLEAN}, regardless of the argument types.
 *
 * @param left the first argument of the comparison
 * @param right the second argument of the comparison
 * @return the invocation produced by {@code scalarFn} for the two arguments
 */
public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) {
  final Type resultType = R.BOOLEAN;
  return scalarFn(FUNCTIONS_COMPARISON, "equal:any_any", resultType, left, right);
}

public Expression.ScalarFunctionInvocation scalarFn(
String namespace, String key, Type outputType, Expression... args) {
var declaration =
Expand Down
13 changes: 9 additions & 4 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.substrait.relation;

import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
Expand Down Expand Up @@ -108,7 +106,12 @@ public Rel from(io.substrait.proto.Rel rel) {

private Rel newRead(ReadRel rel) {
if (rel.hasVirtualTable()) {
return newVirtualTable(rel);
var virtualTable = rel.getVirtualTable();
if (virtualTable.getValuesCount() == 0) {
return newEmptyScan(rel);
} else {
return newVirtualTable(rel);
}
} else if (rel.hasNamedTable()) {
return newNamedScan(rel);
} else if (rel.hasLocalFiles()) {
Expand Down Expand Up @@ -304,7 +307,9 @@ private FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) {

private VirtualTableScan newVirtualTable(ReadRel rel) {
var virtualTable = rel.getVirtualTable();
var converter = new ProtoExpressionConverter(lookup, extensions, EMPTY_TYPE, this);
var virtualTableSchema = newNamedStruct(rel);
var converter =
new ProtoExpressionConverter(lookup, extensions, virtualTableSchema.struct(), this);
List<Expression.StructLiteral> structLiterals = new ArrayList<>(virtualTable.getValuesCount());
for (var struct : virtualTable.getValuesList()) {
structLiterals.add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ public Rel visit(NamedScan namedScan) throws RuntimeException {
.setNamedTable(ReadRel.NamedTable.newBuilder().addAllNames(namedScan.getNames()))
.setBaseSchema(namedScan.getInitialSchema().toProto(typeProtoConverter));

namedScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f)));

namedScan.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setRead(builder).build();
}
Expand Down Expand Up @@ -316,6 +318,8 @@ public Rel visit(VirtualTableScan virtualTableScan) throws RuntimeException {
.build())
.setBaseSchema(virtualTableScan.getInitialSchema().toProto(typeProtoConverter));

virtualTableScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f)));

virtualTableScan.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setRead(builder).build();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package io.substrait.type.proto;

import io.substrait.TestBase;
import io.substrait.expression.ExpressionCreator;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.NamedScan;
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;

/**
 * Round-trip checks for read relations: a filtered {@link NamedScan}, an {@link EmptyScan}, and a
 * filtered {@link VirtualTableScan}.
 *
 * <p>Each test builds a relation and hands it to {@code verifyRoundTrip}; the builder {@code b} and
 * the type creator {@code R} are not declared in this class (presumably inherited from {@link
 * TestBase} — confirm there).
 */
public class ReadRelRoundtripTest extends TestBase {

  /** Named scan over two i64 columns, filtered on {@code column1 = column2}. */
  @Test
  void namedScan() {
    List<String> table = Stream.of("a_table").collect(Collectors.toList());
    List<String> columns = Stream.of("column1", "column2").collect(Collectors.toList());
    List<Type> types = Stream.of(R.I64, R.I64).collect(Collectors.toList());

    // Build the plain scan first so the filter's field references can be resolved against it.
    var unfiltered = b.namedScan(table, columns, types);
    var columnsEqual = b.equal(b.fieldReference(unfiltered, 0), b.fieldReference(unfiltered, 1));

    NamedScan filtered = NamedScan.builder().from(unfiltered).filter(columnsEqual).build();

    verifyRoundTrip(filtered);
  }

  /** Empty scan whose initial schema has no field names and an empty struct type. */
  @Test
  void emptyScan() {
    var noColumns = NamedStruct.of(Collections.emptyList(), R.struct());
    verifyRoundTrip(EmptyScan.builder().initialSchema(noColumns).build());
  }

  /** Virtual table with a single (1, 2) row, filtered on field 0 = field 1. */
  @Test
  void virtualTable() {
    List<String> columns = Stream.of("column1", "column2").collect(Collectors.toList());
    var singleRow =
        ExpressionCreator.struct(
            false, ExpressionCreator.i64(false, 1), ExpressionCreator.i64(false, 2));

    // The unfiltered table is needed first: the filter refers to its fields by index.
    VirtualTableScan unfiltered =
        VirtualTableScan.builder().addAllDfsNames(columns).addRows(singleRow).build();
    var columnsEqual = b.equal(b.fieldReference(unfiltered, 0), b.fieldReference(unfiltered, 1));

    VirtualTableScan filtered =
        VirtualTableScan.builder().from(unfiltered).filter(columnsEqual).build();

    verifyRoundTrip(filtered);
  }
}

0 comments on commit 7d478f0

Please sign in to comment.