From 12abfe79173f6ab00b3341f3b31cad5aa26aa6e4 Mon Sep 17 00:00:00 2001
From: gengjiaan
Date: Sun, 18 Apr 2021 18:03:50 +0300
Subject: [PATCH] [SPARK-34716][SQL] Support ANSI SQL intervals by the aggregate function `sum`

### What changes were proposed in this pull request?
Extend the `Sum` expression to support `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614.

Note: the expressions can throw the overflow exception independently of the SQL config `spark.sql.ansi.enabled`. In this way, the modified expressions always behave in the ANSI mode for the intervals.
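For illustration, a minimal sketch of the new behavior (the column names `ym`/`dt` and the spark-shell setup below are assumed for this example and are not part of the patch):

```scala
// Assumes a spark-shell session, i.e. a SparkSession `spark` with spark.implicits._ in scope.
import java.time.{Duration, Period}
import org.apache.spark.sql.functions.sum

val df = Seq(
  (Period.ofMonths(10), Duration.ofDays(10)),
  (Period.ofMonths(1), Duration.ofDays(1))
).toDF("ym", "dt")

// `sum` now accepts YearMonthIntervalType and DayTimeIntervalType columns and keeps the
// input type: this query yields Period.ofMonths(11) and Duration.ofDays(11).
df.select(sum($"ym"), sum($"dt")).show()
```

The aggregation still goes through the regular hash aggregate, and the result schema carries the interval types; see the new test in `DataFrameAggregateSuite` below.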
### Why are the changes needed?
Extend `org.apache.spark.sql.catalyst.expressions.aggregate.Sum` to support `DayTimeIntervalType` and `YearMonthIntervalType`.

### Does this PR introduce _any_ user-facing change?
No. It should not, since the new types have not been released yet.

### How was this patch tested?
Jenkins tests.

Closes #32107 from beliefer/SPARK-34716.

Lead-authored-by: gengjiaan
Co-authored-by: beliefer
Co-authored-by: Hyukjin Kwon
Signed-off-by: Max Gekk
---
 .../sql/catalyst/expressions/UnsafeRow.java   |  4 +-
 .../catalyst/expressions/aggregate/Sum.scala  | 14 +++++--
 .../ExpressionTypeCheckingSuite.scala         |  2 +-
 .../vectorized/OnHeapColumnVector.java        |  4 +-
 .../spark/sql/execution/aggregate/udaf.scala  | 24 +++++++++++
 .../spark/sql/DataFrameAggregateSuite.scala   | 41 +++++++++++++++++++
 6 files changed, 81 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 4dc5ce1de047b..0c6685d76fd04 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -90,7 +90,9 @@ public static int calculateBitSetWidthInBytes(int numFields) {
       FloatType,
       DoubleType,
       DateType,
-      TimestampType
+      TimestampType,
+      YearMonthIntervalType,
+      DayTimeIntervalType
     })));
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 56eebedddf08d..8ea687d78aaee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -21,7 +21,6 @@
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -46,15 +45,22 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   // Return data type.
   override def dataType: DataType = resultType

-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))

-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function sum")
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
+    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
+    case other => TypeCheckResult.TypeCheckFailure(
+      s"function sum requires numeric or interval types, not ${other.catalogString}")
+  }

   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
     case _: IntegralType => LongType
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 44f333342d1c8..1b9135eef69f3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -158,7 +158,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(Min(Symbol("mapField")), "min does not support ordering on type")
     assertError(Max(Symbol("mapField")), "max does not support ordering on type")
-    assertError(Sum(Symbol("booleanField")), "function sum requires numeric type")
+    assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
     assertError(Average(Symbol("booleanField")), "function average requires numeric type")
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 5a7d6cc20971b..5942c5f00a710 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -541,14 +541,14 @@ protected void reserveInternal(int newCapacity) {
         shortData = newData;
       }
     } else if (type instanceof IntegerType || type instanceof DateType ||
-        DecimalType.is32BitDecimalType(type)) {
+        DecimalType.is32BitDecimalType(type) || type instanceof YearMonthIntervalType) {
       if (intData == null || intData.length < newCapacity) {
         int[] newData = new int[newCapacity];
         if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity);
         intData = newData;
       }
     } else if (type instanceof LongType || type instanceof TimestampType ||
-        DecimalType.is64BitDecimalType(type)) {
+        DecimalType.is64BitDecimalType(type) || type instanceof DayTimeIntervalType) {
       if (longData == null || longData.length < newCapacity) {
         long[] newData = new long[newCapacity];
         if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 1aae76e0fb29b..33cff7ff2b801 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -87,6 +87,14 @@ sealed trait BufferSetterGetterUtils {
         (row: InternalRow, ordinal: Int) =>
           if (row.isNullAt(ordinal)) null else row.getLong(ordinal)

+      case YearMonthIntervalType =>
+        (row: InternalRow, ordinal: Int) =>
+          if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+      case DayTimeIntervalType =>
+        (row: InternalRow, ordinal: Int) =>
+          if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
       case other =>
         (row: InternalRow, ordinal: Int) =>
           if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
@@ -187,6 +195,22 @@ sealed trait BufferSetterGetterUtils {
             row.setNullAt(ordinal)
           }

+      case YearMonthIntervalType =>
+        (row: InternalRow, ordinal: Int, value: Any) =>
+          if (value != null) {
+            row.setInt(ordinal, value.asInstanceOf[Int])
+          } else {
+            row.setNullAt(ordinal)
+          }
+
+      case DayTimeIntervalType =>
+        (row: InternalRow, ordinal: Int, value: Any) =>
+          if (value != null) {
+            row.setLong(ordinal, value.asInstanceOf[Long])
+          } else {
+            row.setNullAt(ordinal)
+          }
+
       case other =>
         (row: InternalRow, ordinal: Int, value: Any) =>
           if (value != null) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 3e137d49e64c3..92d3dc6fb88ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,10 +17,13 @@

 package org.apache.spark.sql

+import java.time.{Duration, Period}
+
 import scala.util.Random

 import org.scalatest.matchers.must.Matchers.the

+import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1110,6 +1113,44 @@ class DataFrameAggregateSuite extends QueryTest
     val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count())
     assert(e.message.contains("requires integral type"))
   }
+
+  test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") {
+    val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
+      (2, Period.ofMonths(1), Duration.ofDays(1)),
+      (2, null, null),
+      (3, Period.ofMonths(-3), Duration.ofDays(-6)),
+      (3, Period.ofMonths(21), Duration.ofDays(-5)))
+      .toDF("class", "year-month", "day-time")
+
+    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
+      (Period.ofMonths(10), Duration.ofDays(10)))
+      .toDF("year-month", "day-time")
+
+    val sumDF = df.select(sum($"year-month"), sum($"day-time"))
+    checkAnswer(sumDF, Row(Period.of(2, 5, 0), Duration.ofDays(0)))
+    assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(sumDF.schema == StructType(Seq(StructField("sum(year-month)", YearMonthIntervalType),
+      StructField("sum(day-time)", DayTimeIntervalType))))
+
+    val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time"))
+    checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
+      Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
+      Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil)
+    assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
+      StructField("sum(year-month)", YearMonthIntervalType),
+      StructField("sum(day-time)", DayTimeIntervalType))))
+
+    val error = intercept[SparkException] {
+      checkAnswer(df2.select(sum($"year-month")), Nil)
+    }
+    assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
+
+    val error2 = intercept[SparkException] {
+      checkAnswer(df2.select(sum($"day-time")), Nil)
+    }
+    assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+  }
 }

 case class B(c: Option[Double])
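
For reference, a minimal spark-shell sketch of the overflow behavior exercised by the test above (the column names `ym`/`dt` and the session setup are assumed for illustration only):

```scala
// Assumes a spark-shell session with spark.implicits._ in scope, as in the earlier example.
import java.time.{Duration, Period}
import org.apache.spark.sql.functions.sum

val df2 = Seq(
  (Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
  (Period.ofMonths(10), Duration.ofDays(10))
).toDF("ym", "dt")

// Both queries fail with a SparkException caused by java.lang.ArithmeticException
// ("integer overflow" / "long overflow"), regardless of spark.sql.ansi.enabled.
df2.select(sum($"ym")).collect()
df2.select(sum($"dt")).collect()
```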