Skip to content

[MINOR]: Add new test for filter pushdown into cross join #8648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl EliminateCrossJoin {
/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
/// or (a.x = b.y and b.xx = 200 and a.z=c.z);'
/// 'select ... from a, b where a.x > b.y'
/// For above queries, the join predicate is available in filters and they are moved to
/// join nodes appropriately
/// This fix helps to improve the performance of TPCH Q19. issue#78
Expand Down
12 changes: 9 additions & 3 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -965,11 +965,11 @@ impl PushDownFilter {
}
}

/// Convert cross join to join by pushing down filter predicate to the join condition
/// Converts the given cross join to an inner join with an empty equality
/// predicate and an empty filter condition.
fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
let CrossJoin { left, right, .. } = cross_join;
let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?;
// predicate is given
Ok(Join {
left,
right,
Expand All @@ -982,7 +982,8 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
})
}

/// Converts the inner join with empty equality predicate and empty filter condition to the cross join
/// Converts the given inner join with an empty equality predicate and an
/// empty filter condition to a cross join.
fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result<LogicalPlan> {
if let LogicalPlan::Join(join) = &plan {
// Can be converted back to cross join
Expand All @@ -991,6 +992,11 @@ fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result<LogicalPlan>
.cross_join(join.right.as_ref().clone())?
.build();
}
} else if let LogicalPlan::Filter(filter) = &plan {
let new_input =
convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?;
return Filter::try_new(filter.predicate.clone(), Arc::new(new_input))
.map(LogicalPlan::Filter);
}
Ok(plan)
}
Expand Down
61 changes: 46 additions & 15 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,33 @@
// specific language governing permissions and limitations
// under the License.

use async_trait::async_trait;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;

use arrow::array::{
ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray,
StringArray, TimestampNanosecondArray,
};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::record_batch::RecordBatch;
use datafusion::execution::context::SessionState;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility};
use datafusion::physical_expr::functions::make_scalar_function;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion::{
arrow::{
array::{
BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray,
StringArray, TimestampNanosecondArray,
},
datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
record_batch::RecordBatch,
},
catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider},
datasource::{MemTable, TableProvider, TableType},
prelude::{CsvReadOptions, SessionContext},
};
use datafusion_common::cast::as_float64_array;
use datafusion_common::DataFusionError;

use async_trait::async_trait;
use log::info;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;
use tempfile::TempDir;

/// Context for running tests
Expand Down Expand Up @@ -102,6 +104,8 @@ impl TestContext {
}
"joins.slt" => {
info!("Registering partition table tables");
let example_udf = create_example_udf();
test_ctx.ctx.register_udf(example_udf);
register_partition_table(&mut test_ctx).await;
}
"metadata.slt" => {
Expand Down Expand Up @@ -348,3 +352,30 @@ pub async fn register_metadata_tables(ctx: &SessionContext) {

ctx.register_batch("table_with_metadata", batch).unwrap();
}

/// Builds the scalar UDF that is registered under the name "example".
///
/// The UDF is declared to take two `Float64` arguments and to return a
/// `Float64`. Each output row is the sum of the corresponding input rows; a
/// row is null whenever either input row is null.
///
/// See the scalar UDF example in the DataFusion examples for an explanation
/// of the API.
fn create_example_udf() -> ScalarUDF {
    // The closure is passed inline so that its signature is inferred from
    // `make_scalar_function`'s bound.
    let adder = make_scalar_function(|args: &[ArrayRef]| {
        // The signature below declares both arguments as Float64; the
        // `expect` calls panic if either downcast fails.
        let left = as_float64_array(&args[0]).expect("cast failed");
        let right = as_float64_array(&args[1]).expect("cast failed");

        // Row-wise addition: `Option::zip` yields `Some` only when both
        // operands are present, so a null on either side produces a null.
        let sums: Float64Array = left
            .iter()
            .zip(right.iter())
            .map(|(l, r)| l.zip(r).map(|(l, r)| l + r))
            .collect();

        Ok(Arc::new(sums) as ArrayRef)
    });

    // Expects two f64 values:
    let input_types = vec![DataType::Float64, DataType::Float64];
    // Returns an f64 value:
    let return_type = Arc::new(DataType::Float64);

    create_udf(
        "example",
        input_types,
        return_type,
        Volatility::Immutable,
        adder,
    )
}
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3483,6 +3483,28 @@ NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1
----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true

# Currently DataFusion cannot push down filter conditions that contain a
# scalar UDF into a cross join.
query TT
EXPLAIN SELECT *
FROM annotated_data as t1, annotated_data as t2
WHERE EXAMPLE(t1.a, t2.a) > 3
----
logical_plan
Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3)
--CrossJoin:
----SubqueryAlias: t1
------TableScan: annotated_data projection=[a0, a, b, c, d]
----SubqueryAlias: t2
------TableScan: annotated_data projection=[a0, a, b, c, d]
physical_plan
CoalesceBatchesExec: target_batch_size=2
--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3
----CrossJoinExec
------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true

####
# Config teardown
####
Expand Down