[SPARK-39385][SQL] Supports push down REGR_AVGX and REGR_AVGY
### What changes were proposed in this pull request?
#36773 translated the linear regression aggregate functions for pushdown.
Although `REGR_AVGX` and `REGR_AVGY` are replaced with `AVG` at runtime, we can push down `AVG` to achieve the same result as pushing down `REGR_AVGX` and `REGR_AVGY` directly.

Take `RegrAvgX` as an example: at runtime `RegrAvgX` is replaced with `Average(If(And(IsNotNull(left), IsNotNull(right)), right, Literal.create(null, right.dataType)))`, and the latter is then optimized into `Average(CaseWhen(Seq[(And(IsNotNull(left), IsNotNull(right)), right)], Some(Literal.create(null, right.dataType))))`.
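
For illustration, the equivalence this relies on can be sketched without any Spark machinery: `REGR_AVGX(y, x)` averages `x` over the rows where both arguments are non-null, which is exactly what `AVG` over the null-masking `CASE WHEN` computes. A minimal, self-contained Scala sketch (the data is made up, and this is plain collections code, not the Catalyst rewrite itself):

```scala
// Each row is a (y, x) pair; None plays the role of SQL NULL.
object RegrAvgXSketch extends App {
  val rows: Seq[(Option[Double], Option[Double])] =
    Seq(Some(1.0) -> Some(10.0), None -> Some(20.0), Some(3.0) -> None, Some(4.0) -> Some(40.0))

  // Direct definition of REGR_AVGX(y, x): keep rows where both sides are non-null, average x.
  val direct: Double = {
    val xs = rows.collect { case (Some(_), Some(x)) => x }
    xs.sum / xs.size
  }

  // Rewritten form: CASE WHEN y IS NOT NULL AND x IS NOT NULL THEN x ELSE NULL END,
  // then AVG, which ignores the nulls.
  val viaAvgCaseWhen: Double = {
    val masked = rows.flatMap { case (y, x) => if (y.isDefined && x.isDefined) x else None }
    masked.sum / masked.size
  }

  assert(direct == viaAvgCaseWhen) // both 25.0 for this data
  println(s"REGR_AVGX = $direct")
}
```

With `RegrAvgX(bonus, bonus)` as in the tests below, the two `IS NOT NULL` checks collapse into one, which is why the pushed-down SQL is `AVG(CASE WHEN BONUS IS NOT NULL THEN BONUS ELSE null END)`.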

Note the `Literal.create(null, right.dataType)` here: `visitLiteral` of `JDBCSQLBuilder` cannot process a null literal correctly, so we need to fix that issue as well.
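
A minimal sketch of that problem, assuming a drastically simplified `compileValue` (the real one lives in `JdbcDialect` and handles many more types): with a null literal value, the old `compileValue(...).toString` path blows up, whereas wrapping the value in `Option` lets the builder fall back to a plain SQL `NULL` (standing in here for `super.visitLiteral`).

```scala
object NullLiteralSketch extends App {
  // Stand-in for JdbcDialect.compileValue: quote strings, let everything else
  // (including null) fall through unchanged.
  def compileValue(value: Any): Any = value match {
    case s: String => s"'$s'"
    case _ => value
  }

  // Old behaviour: .toString on the fall-through value fails for a null literal.
  def visitLiteralOld(value: Any): String = compileValue(value).toString

  // Fixed behaviour, mirroring the Option(...).map(...).getOrElse(...) pattern in this commit.
  def visitLiteralFixed(value: Any): String =
    Option(value).map(v => compileValue(v).toString).getOrElse("NULL")

  println(visitLiteralFixed("abc")) // 'abc'
  println(visitLiteralFixed(null))  // NULL
  // visitLiteralOld(null) would throw a NullPointerException in this simplified model.
}
```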

### Why are the changes needed?
Let aggregate pushdown support `REGR_AVGX` and `REGR_AVGY`.

### Does this PR introduce _any_ user-facing change?
No. This is a new feature.

### How was this patch tested?
New test cases.

Closes #37126 from beliefer/SPARK-39385_followup.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
beliefer authored and cloud-fan committed Jul 13, 2022
1 parent 0b1077c commit a79c91e
Showing 2 changed files with 35 additions and 2 deletions.
@@ -224,8 +224,9 @@ abstract class JdbcDialect extends Serializable with Logging {

  private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
    override def visitLiteral(literal: Literal[_]): String = {
-      compileValue(
-        CatalystTypeConverters.convertToScala(literal.value(), literal.dataType())).toString
+      Option(literal.value()).map(v =>
+        compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString)
+        .getOrElse(super.visitLiteral(literal))
    }

    override def visitNamedReference(namedRef: NamedReference): String = {
@@ -1824,6 +1824,38 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
    checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], ReadSchema:")
    checkAnswer(df2,
      Seq(Row(0.0, 1.0, 1.0, 20000.0), Row(0.0, 1.0, 1.0, 5000.0), Row(null, null, null, 0.0)))
+
+    val df3 = sql(
+      """
+        |SELECT
+        | REGR_AVGX(bonus, bonus),
+        | REGR_AVGY(bonus, bonus)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df3)
+    checkAggregateRemoved(df3)
+    checkPushedInfo(df3,
+      """
+        |PushedAggregates: [AVG(CASE WHEN BONUS IS NOT NULL THEN BONUS ELSE null END)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df3, Seq(Row(1100.0, 1100.0), Row(1200.0, 1200.0), Row(1250.0, 1250.0)))
+
+    val df4 = sql(
+      """
+        |SELECT
+        | REGR_AVGX(DISTINCT bonus, bonus),
+        | REGR_AVGY(DISTINCT bonus, bonus)
+        |FROM h2.test.employee WHERE dept > 0 GROUP BY DePt""".stripMargin)
+    checkFiltersRemoved(df4)
+    checkAggregateRemoved(df4)
+    checkPushedInfo(df4,
+      """
+        |PushedAggregates: [AVG(DISTINCT CASE WHEN BONUS IS NOT NULL THEN BONUS ELSE null END)],
+        |PushedFilters: [DEPT IS NOT NULL, DEPT > 0],
+        |PushedGroupByExpressions: [DEPT],
+        |""".stripMargin.replaceAll("\n", " "))
+    checkAnswer(df4, Seq(Row(1100.0, 1100.0), Row(1200.0, 1200.0), Row(1250.0, 1250.0)))
  }

  test("scan with aggregate push-down: aggregate over alias push down") {
