Skip to content

Commit 2762754

Browse files
authored
Support Substrait's VirtualTables (#10531)
* Add support for Substrait VirtualTables Adds support for Substrait's VirtualTables, ie. tables with data baked-in into the Substrait plan instead of being read from a source. Adds conversion in both ways (Substrait -> DataFusion and DataFusion -> Substrait) and a roundtrip test. * fix clippy * Add support for empty relations * Fix consuming Structs inside Lists and Structs Also adds roundtrip schema assertions for cases where possible * Rename from_substrait_struct -> from_substrait_struct_type for clarity * Add DataType::LargeList to to_substrait_named_struct * cargo fmt --all * Add validation that names list matches schema exactly * Add a LargeList into VALUES test
1 parent 5a9712e commit 2762754

File tree

3 files changed

+361
-62
lines changed

3 files changed

+361
-62
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 157 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
// under the License.
1717

1818
use async_recursion::async_recursion;
19-
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
19+
use datafusion::arrow::datatypes::{
20+
DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
21+
};
2022
use datafusion::common::{
2123
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
2224
};
2325

2426
use datafusion::execution::FunctionRegistry;
2527
use datafusion::logical_expr::{
26-
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, Expr, LogicalPlan,
27-
Operator, ScalarUDF,
28+
aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr,
29+
LogicalPlan, Operator, ScalarUDF, Values,
2830
};
2931
use datafusion::logical_expr::{
3032
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
@@ -58,7 +60,7 @@ use substrait::proto::{
5860
rel::RelType,
5961
set_rel,
6062
sort_field::{SortDirection, SortKind::*},
61-
AggregateFunction, Expression, Plan, Rel, Type,
63+
AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
6264
};
6365
use substrait::proto::{FunctionArgument, SortField};
6466

@@ -509,7 +511,51 @@ pub async fn from_substrait_rel(
509511
_ => Ok(t),
510512
}
511513
}
512-
_ => not_impl_err!("Only NamedTable reads are supported"),
514+
Some(ReadType::VirtualTable(vt)) => {
515+
let base_schema = read.base_schema.as_ref().ok_or_else(|| {
516+
substrait_datafusion_err!("No base schema provided for Virtual Table")
517+
})?;
518+
519+
let schema = from_substrait_named_struct(base_schema)?;
520+
521+
if vt.values.is_empty() {
522+
return Ok(LogicalPlan::EmptyRelation(EmptyRelation {
523+
produce_one_row: false,
524+
schema,
525+
}));
526+
}
527+
528+
let values = vt
529+
.values
530+
.iter()
531+
.map(|row| {
532+
let mut name_idx = 0;
533+
let lits = row
534+
.fields
535+
.iter()
536+
.map(|lit| {
537+
name_idx += 1; // top-level names are provided through schema
538+
Ok(Expr::Literal(from_substrait_literal(
539+
lit,
540+
&base_schema.names,
541+
&mut name_idx,
542+
)?))
543+
})
544+
.collect::<Result<_>>()?;
545+
if name_idx != base_schema.names.len() {
546+
return substrait_err!(
547+
"Names list must match exactly to nested schema, but found {} uses for {} names",
548+
name_idx,
549+
base_schema.names.len()
550+
);
551+
}
552+
Ok(lits)
553+
})
554+
.collect::<Result<_>>()?;
555+
556+
Ok(LogicalPlan::Values(Values { schema, values }))
557+
}
558+
_ => not_impl_err!("Only NamedTable and VirtualTable reads are supported"),
513559
},
514560
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
515561
Ok(set_op) => match set_op {
@@ -948,7 +994,7 @@ pub async fn from_substrait_rex(
948994
}
949995
}
950996
Some(RexType::Literal(lit)) => {
951-
let scalar_value = from_substrait_literal(lit)?;
997+
let scalar_value = from_substrait_literal_without_names(lit)?;
952998
Ok(Arc::new(Expr::Literal(scalar_value)))
953999
}
9541000
Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() {
@@ -964,9 +1010,9 @@ pub async fn from_substrait_rex(
9641010
.as_ref()
9651011
.clone(),
9661012
),
967-
from_substrait_type(output_type)?,
1013+
from_substrait_type_without_names(output_type)?,
9681014
)))),
969-
None => substrait_err!("Cast experssion without output type is not allowed"),
1015+
None => substrait_err!("Cast expression without output type is not allowed"),
9701016
},
9711017
Some(RexType::WindowFunction(window)) => {
9721018
let fun = match extensions.get(&window.function_reference) {
@@ -1062,7 +1108,15 @@ pub async fn from_substrait_rex(
10621108
}
10631109
}
10641110

1065-
pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
1111+
pub(crate) fn from_substrait_type_without_names(dt: &Type) -> Result<DataType> {
1112+
from_substrait_type(dt, &[], &mut 0)
1113+
}
1114+
1115+
fn from_substrait_type(
1116+
dt: &Type,
1117+
dfs_names: &[String],
1118+
name_idx: &mut usize,
1119+
) -> Result<DataType> {
10661120
match &dt.kind {
10671121
Some(s_kind) => match s_kind {
10681122
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
@@ -1142,7 +1196,7 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
11421196
substrait_datafusion_err!("List type must have inner type")
11431197
})?;
11441198
let field = Arc::new(Field::new_list_field(
1145-
from_substrait_type(inner_type)?,
1199+
from_substrait_type(inner_type, dfs_names, name_idx)?,
11461200
is_substrait_type_nullable(inner_type)?,
11471201
));
11481202
match list.type_variation_reference {
@@ -1182,24 +1236,69 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
11821236
),
11831237
}
11841238
},
1185-
r#type::Kind::Struct(s) => {
1186-
let mut fields = vec![];
1187-
for (i, f) in s.types.iter().enumerate() {
1188-
let field = Field::new(
1189-
&format!("c{i}"),
1190-
from_substrait_type(f)?,
1191-
is_substrait_type_nullable(f)?,
1192-
);
1193-
fields.push(field);
1194-
}
1195-
Ok(DataType::Struct(fields.into()))
1196-
}
1239+
r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type(
1240+
s, dfs_names, name_idx,
1241+
)?)),
11971242
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
11981243
},
11991244
_ => not_impl_err!("`None` Substrait kind is not supported"),
12001245
}
12011246
}
12021247

1248+
fn from_substrait_struct_type(
1249+
s: &r#type::Struct,
1250+
dfs_names: &[String],
1251+
name_idx: &mut usize,
1252+
) -> Result<Fields> {
1253+
let mut fields = vec![];
1254+
for (i, f) in s.types.iter().enumerate() {
1255+
let field = Field::new(
1256+
next_struct_field_name(i, dfs_names, name_idx)?,
1257+
from_substrait_type(f, dfs_names, name_idx)?,
1258+
is_substrait_type_nullable(f)?,
1259+
);
1260+
fields.push(field);
1261+
}
1262+
Ok(fields.into())
1263+
}
1264+
1265+
fn next_struct_field_name(
1266+
i: usize,
1267+
dfs_names: &[String],
1268+
name_idx: &mut usize,
1269+
) -> Result<String> {
1270+
if dfs_names.is_empty() {
1271+
// If names are not given, create dummy names
1272+
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
1273+
Ok(format!("c{i}"))
1274+
} else {
1275+
let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| {
1276+
substrait_datafusion_err!("Named schema must contain names for all fields")
1277+
})?;
1278+
*name_idx += 1;
1279+
Ok(name)
1280+
}
1281+
}
1282+
1283+
fn from_substrait_named_struct(base_schema: &NamedStruct) -> Result<DFSchemaRef> {
1284+
let mut name_idx = 0;
1285+
let fields = from_substrait_struct_type(
1286+
base_schema.r#struct.as_ref().ok_or_else(|| {
1287+
substrait_datafusion_err!("Named struct must contain a struct")
1288+
})?,
1289+
&base_schema.names,
1290+
&mut name_idx,
1291+
);
1292+
if name_idx != base_schema.names.len() {
1293+
return substrait_err!(
1294+
"Names list must match exactly to nested schema, but found {} uses for {} names",
1295+
name_idx,
1296+
base_schema.names.len()
1297+
);
1298+
}
1299+
Ok(DFSchemaRef::new(DFSchema::try_from(Schema::new(fields?))?))
1300+
}
1301+
12031302
fn is_substrait_type_nullable(dtype: &Type) -> Result<bool> {
12041303
fn is_nullable(nullability: i32) -> bool {
12051304
nullability != substrait::proto::r#type::Nullability::Required as i32
@@ -1277,7 +1376,15 @@ fn from_substrait_bound(
12771376
}
12781377
}
12791378

1280-
pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
1379+
pub(crate) fn from_substrait_literal_without_names(lit: &Literal) -> Result<ScalarValue> {
1380+
from_substrait_literal(lit, &vec![], &mut 0)
1381+
}
1382+
1383+
fn from_substrait_literal(
1384+
lit: &Literal,
1385+
dfs_names: &Vec<String>,
1386+
name_idx: &mut usize,
1387+
) -> Result<ScalarValue> {
12811388
let scalar_value = match &lit.literal_type {
12821389
Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
12831390
Some(LiteralType::I8(n)) => match lit.type_variation_reference {
@@ -1359,7 +1466,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
13591466
let elements = l
13601467
.values
13611468
.iter()
1362-
.map(from_substrait_literal)
1469+
.map(|el| from_substrait_literal(el, dfs_names, name_idx))
13631470
.collect::<Result<Vec<_>>>()?;
13641471
if elements.is_empty() {
13651472
return substrait_err!(
@@ -1381,7 +1488,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
13811488
}
13821489
}
13831490
Some(LiteralType::EmptyList(l)) => {
1384-
let element_type = from_substrait_type(l.r#type.clone().unwrap().as_ref())?;
1491+
let element_type = from_substrait_type(
1492+
l.r#type.clone().unwrap().as_ref(),
1493+
dfs_names,
1494+
name_idx,
1495+
)?;
13851496
match lit.type_variation_reference {
13861497
DEFAULT_CONTAINER_TYPE_REF => {
13871498
ScalarValue::List(ScalarValue::new_list(&[], &element_type))
@@ -1397,16 +1508,16 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
13971508
Some(LiteralType::Struct(s)) => {
13981509
let mut builder = ScalarStructBuilder::new();
13991510
for (i, field) in s.fields.iter().enumerate() {
1400-
let sv = from_substrait_literal(field)?;
1401-
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
1402-
builder = builder.with_scalar(
1403-
Field::new(&format!("c{i}"), sv.data_type(), field.nullable),
1404-
sv,
1405-
);
1511+
let name = next_struct_field_name(i, dfs_names, name_idx)?;
1512+
let sv = from_substrait_literal(field, dfs_names, name_idx)?;
1513+
builder = builder
1514+
.with_scalar(Field::new(name, sv.data_type(), field.nullable), sv);
14061515
}
14071516
builder.build()?
14081517
}
1409-
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
1518+
Some(LiteralType::Null(ntype)) => {
1519+
from_substrait_null(ntype, dfs_names, name_idx)?
1520+
}
14101521
Some(LiteralType::UserDefined(user_defined)) => {
14111522
match user_defined.type_reference {
14121523
INTERVAL_YEAR_MONTH_TYPE_REF => {
@@ -1461,7 +1572,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
14611572
Ok(scalar_value)
14621573
}
14631574

1464-
fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
1575+
fn from_substrait_null(
1576+
null_type: &Type,
1577+
dfs_names: &[String],
1578+
name_idx: &mut usize,
1579+
) -> Result<ScalarValue> {
14651580
if let Some(kind) = &null_type.kind {
14661581
match kind {
14671582
r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)),
@@ -1539,7 +1654,11 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
15391654
)),
15401655
r#type::Kind::List(l) => {
15411656
let field = Field::new_list_field(
1542-
from_substrait_type(l.r#type.clone().unwrap().as_ref())?,
1657+
from_substrait_type(
1658+
l.r#type.clone().unwrap().as_ref(),
1659+
dfs_names,
1660+
name_idx,
1661+
)?,
15431662
true,
15441663
);
15451664
match l.type_variation_reference {
@@ -1554,6 +1673,10 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
15541673
),
15551674
}
15561675
}
1676+
r#type::Kind::Struct(s) => {
1677+
let fields = from_substrait_struct_type(s, dfs_names, name_idx)?;
1678+
Ok(ScalarStructBuilder::new_null(fields))
1679+
}
15571680
_ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"),
15581681
}
15591682
} else {

0 commit comments

Comments
 (0)