Skip to content

Commit

Permalink
fix: account for struct fields in VirtualTableScan check
Browse files Browse the repository at this point in the history
  • Loading branch information
Blizzara committed May 3, 2024
1 parent a49de62 commit 0097759
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 1 deletion.
140 changes: 139 additions & 1 deletion core/src/main/java/io/substrait/relation/VirtualTableScan.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.substrait.expression.Expression;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeVisitor;
import java.util.List;
import org.immutables.value.Value;

Expand Down Expand Up @@ -30,7 +31,9 @@ protected void check() {

assert rows.size() > 0
&& names.stream().noneMatch(s -> s == null)
&& rows.stream().noneMatch(r -> r == null || r.fields().size() != names.size());
&& rows.stream().noneMatch(r -> r == null)
&& rows.stream()
.allMatch(r -> r.getType().accept(new NamedFieldCountingTypeVisitor()) == names.size());
}

@Override
Expand All @@ -46,4 +49,139 @@ public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
public static ImmutableVirtualTableScan.Builder builder() {
return ImmutableVirtualTableScan.builder();
}

private static class NamedFieldCountingTypeVisitor
implements TypeVisitor<Integer, RuntimeException> {
@Override
public Integer visit(Type.Bool type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.I8 type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.I16 type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.I32 type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.I64 type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.FP32 type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.FP64 type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.Str type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.Binary type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.Date type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.Time type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.TimestampTZ type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.Timestamp type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.PrecisionTimestamp type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.PrecisionTimestampTZ type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.IntervalYear type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.IntervalDay type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.UUID type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.FixedChar type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.VarChar type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.FixedBinary type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.Decimal type) throws RuntimeException {
return 0;
}

@Override
public Integer visit(Type.Struct type) throws RuntimeException {
// Only struct fields have names - the top level column names are also
// captured by this since the whole schema is wrapped in a Struct type
return type.fields().stream().mapToInt(field -> 1 + field.accept(this)).sum();
}

@Override
public Integer visit(Type.ListType type) throws RuntimeException {
return type.elementType().accept(this);
}

@Override
public Integer visit(Type.Map type) throws RuntimeException {
return type.key().accept(this) + type.value().accept(this);
}

@Override
public Integer visit(Type.UserDefined type) throws RuntimeException {
return 0;
}
}
}
52 changes: 52 additions & 0 deletions core/src/test/java/io/substrait/relation/VirtualTableScanTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.substrait.relation;

import static io.substrait.expression.ExpressionCreator.*;
import static org.junit.jupiter.api.Assertions.*;

import io.substrait.expression.Expression;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;

class VirtualTableScanTest {

@Test
void check() {
VirtualTableScan virtualTableScan =
ImmutableVirtualTableScan.builder()
.addDfsNames(
"string",
"struct",
"struct_field1",
"struct_field2",
"list",
"list_struct_field1",
"map",
"map_key_struct_field1",
"map_value_struct_field1")
.addRows(
struct(
false,
string(false, "string_val"),
struct(
false,
string(false, "struct_field1_val"),
string(false, "struct_field2_val")),
list(false, struct(false, string(false, "list_struct_field1_val"))),
map(
false,
mapOf(
struct(false, string(false, "map_key_struct_field1_val")),
struct(false, string(false, "map_value_struct_field1_val"))))))
.build();
assertDoesNotThrow(virtualTableScan::check);
}

private Map<Expression.Literal, Expression.Literal> mapOf(
Expression.Literal key, Expression.Literal value) {
// Map.of() comes only in Java 9 and the "core" module is on Java 8
HashMap<Expression.Literal, Expression.Literal> map = new HashMap<>();
map.put(key, value);
return map;
}
}

0 comments on commit 0097759

Please # to comment.