diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala b/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala index 8efafde9682..f20e4eb33fd 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala @@ -404,6 +404,11 @@ case class DeltaStatsColumnSpec( } object StatisticsCollection extends DeltaCommand { + + val ASCII_MAX_CHARACTER = '\u007F' + + val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT)) + /** * The SQL grammar already includes a `multipartIdentifierList` rule for parsing a string into a * list of multi-part identifiers. We just expose it here, with a custom parser and AstBuilder. @@ -769,28 +774,68 @@ object StatisticsCollection extends DeltaCommand { txn.commit(newAddFiles, ComputeStats(predicates)) } + def truncateMinStringAgg(prefixLen: Int)(input: String): String = { + if (input == null || input.length <= prefixLen) { + return input + } + if (prefixLen == 0) { + return null + } + if (Character.isHighSurrogate(input.charAt(prefixLen - 1)) && + Character.isLowSurrogate(input.charAt(prefixLen))) { + // If the character at prefixLen - 1 is a high surrogate and the next character is a low + // surrogate, we need to include the next character in the prefix to ensure that we don't + // truncate the string in the middle of a surrogate pair. + input.take(prefixLen + 1) + } else { + input.take(prefixLen) + } + } + /** - * Helper method to truncate the input string `x` to the given `prefixLen` length, while also - * appending the unicode max character to the end of the truncated string. This ensures that any - * value in this column is less than or equal to the max. - * Note: Input string `x` must be properly encoded in UTF-8. + * Helper method to truncate the input string `input` to the given `prefixLen` length, while also + * ensuring the any value in this column is less than or equal to the truncated max in both UTF-8 + * encoding. */ - def truncateMaxStringAgg(prefixLen: Int)(x: String): String = { - if (x == null || x.length <= prefixLen) { - x - } else { - // Grab the prefix. We want to append `\ufffd` as a tie-breaker, but that is only safe - // if the character we truncated was smaller. Keep extending the prefix until that - // condition holds, or we run off the end of the string. - // scalastyle:off nonascii - val tieBreaker = '\ufffd' - var ans = x.take(prefixLen) + x.substring(prefixLen).takeWhile(_ >= tieBreaker) - // Append a tie-breaker only if we truncated any characters from input string `x`. - if (ans.length < x.length) { - ans = ans + tieBreaker + def truncateMaxStringAgg(prefixLen: Int)(originalMax: String): String = { + if (originalMax == null || originalMax.length <= prefixLen) { + return originalMax + } + if (prefixLen == 0) { + return null + } + + // Grab the prefix. We want to append max Unicode code point `\uDBFF\uDFFF` as a tie-breaker, + // but that is only safe if the character we truncated was smaller in UTF-8 encoded binary + // comparison. Keep extending the prefix until that condition holds, or we run off the end of + // the string. + // We also try to use the ASCII max character `\u007F` as a tie-breaker if possible. + val maxLen = getExpansionLimit(prefixLen) + // Start with a valid prefix + var currLen = truncateMinStringAgg(prefixLen)(originalMax).length + while (currLen <= maxLen) { + if (currLen >= originalMax.length) { + // Return originalMax if we have reached the end of the string + return originalMax + } else if (currLen + 1 < originalMax.length && + originalMax.substring(currLen, currLen + 2) == UTF8_MAX_CHARACTER) { + // Skip the UTF-8 max character. It occupies two characters in a Scala string. + currLen += 2 + } else if (originalMax.charAt(currLen) < ASCII_MAX_CHARACTER) { + return originalMax.take(currLen) + ASCII_MAX_CHARACTER + } else { + return originalMax.take(currLen) + UTF8_MAX_CHARACTER } - ans - // scalastyle:off nonascii } + + // Return null when the input string is too long to truncate. + null + // scalastyle:off nonascii } + + /** + * Calculates the upper character limit when constructing a maximum is not possible with only + * prefixLen chars. + */ + private def getExpansionLimit(prefixLen: Int): Int = 2 * prefixLen } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatsCollectionUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatsCollectionUtils.scala index 1dc1525f948..f48a0f140b5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatsCollectionUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatsCollectionUtils.scala @@ -514,7 +514,7 @@ abstract class StatsCollector( // the max, check the helper function for more details. StatisticsCollection.truncateMaxStringAgg(stringTruncateLength.get)(rawString) } else { - rawString.substring(0, stringTruncateLength.get) + StatisticsCollection.truncateMinStringAgg(stringTruncateLength.get)(rawString) } } else { rawString diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/stats/StatsCollectionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/stats/StatsCollectionSuite.scala index ace5f184faf..ca2cf1e6188 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/stats/StatsCollectionSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/stats/StatsCollectionSuite.scala @@ -24,6 +24,7 @@ import java.time.LocalDateTime import org.apache.spark.sql.delta._ import org.apache.spark.sql.delta.actions.Protocol import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.stats.StatisticsCollection.{ASCII_MAX_CHARACTER, UTF8_MAX_CHARACTER} import org.apache.spark.sql.delta.test.{DeltaExceptionTestUtils, DeltaSQLCommandTest, DeltaSQLTestUtils, TestsStatistics} import org.apache.spark.sql.delta.test.DeltaTestImplicits._ import org.apache.spark.sql.delta.util.JsonUtils @@ -37,6 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSche import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String class StatsCollectionSuite extends QueryTest @@ -359,31 +361,73 @@ class StatsCollectionSuite } } + test("Truncate min string") { + // scalastyle:off nonascii + val inputToExpected = Seq( + (s"abcd", s"abc", 3), + (s"abcdef", s"abcdef", 6), + (s"abcde�", s"abcde�", 6), + (s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER", + s"$UTF8_MAX_CHARACTER", + 1), + (s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER", s"$UTF8_MAX_CHARACTER", 1), + (s"abcd", null, 0) + ) + + inputToExpected.foreach { + case (input, expected, prefixLen) => + val actual = StatisticsCollection.truncateMinStringAgg(prefixLen)(input) + val debugMsg = s"input:$input, actual:$actual, expected:$expected" + assert(actual == expected, debugMsg) + if (actual != null) { + assert(input.startsWith(actual), debugMsg) + } + } + // scalastyle:on nonascii + } + test("Truncate max string") { // scalastyle:off nonascii - val prefixLen = 6 - // � is the max unicode character with value \ufffd val inputToExpected = Seq( - (s"abcd", s"abcd"), - (s"abcdef", s"abcdef"), - (s"abcde�", s"abcde�"), - (s"abcd�abcd", s"abcd�a�"), - (s"�abcd", s"�abcd"), - (s"abcdef�", s"abcdef�"), - (s"abcdef-abcdef�", s"abcdef�"), - (s"abcdef�abcdef", s"abcdef��"), - (s"abcdef��abcdef", s"abcdef���"), - (s"abcdef�abcdef�abcdef�abcdef", s"abcdef��"), - (s"漢字仮名한글தமி", s"漢字仮名한글�"), - (s"漢字仮名한글��", s"漢字仮名한글��"), - (s"漢字仮名한글", s"漢字仮名한글") + (s"abcd", null, 0), + (s"a${UTF8_MAX_CHARACTER}d", s"a$UTF8_MAX_CHARACTER$ASCII_MAX_CHARACTER", 2), + (s"abcd", s"abcd", 6), + (s"abcdef", s"abcdef", 6), + (s"abcde�", s"abcde�", 6), + (s"abcd�abcd", s"abcd�a$ASCII_MAX_CHARACTER", 6), + (s"�abcd", s"�abcd", 6), + (s"abcdef�", s"abcdef$UTF8_MAX_CHARACTER", 6), + (s"abcdef��", s"abcdef$UTF8_MAX_CHARACTER", 6), + (s"abcdef-abcdef�", s"abcdef$ASCII_MAX_CHARACTER", 6), + (s"abcdef�abcdef", s"abcdef$UTF8_MAX_CHARACTER", 6), + (s"abcde�abcdef�abcdef�abcdef", s"abcde�$ASCII_MAX_CHARACTER", 6), + (s"漢字仮名한글தமி", s"漢字仮名한글$UTF8_MAX_CHARACTER", 6), + (s"漢字仮名한글��", s"漢字仮名한글$UTF8_MAX_CHARACTER", 6), + (s"漢字仮名한글", s"漢字仮名한글", 6), + (s"abcdef🚀", s"abcdef$UTF8_MAX_CHARACTER", 6), + (s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER", null, 1), + (s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER", + s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER", + 4), + (s"����", s"��$UTF8_MAX_CHARACTER", 2), + (s"���", s"�$UTF8_MAX_CHARACTER", 1), + ("abcdefghijklm💞😉💕\n🥀🌹💐🌺🌷🌼🌻🌷🥀", + s"abcdefghijklm💞😉💕\n🥀🌹💐🌺🌷🌼$UTF8_MAX_CHARACTER", + 32) ) + inputToExpected.foreach { - case (input, expected) => + case (input, expected, prefixLen) => val actual = StatisticsCollection.truncateMaxStringAgg(prefixLen)(input) - assert(actual == expected, s"input:$input, actual:$actual, expected:$expected") + // `Actual` should be higher or equal than `input` in UTF-8 encoded binary order. + val debugMsg = s"input:$input, actual:$actual, expected:$expected" + assert(actual == expected, debugMsg) + if (actual != null) { + assert(UTF8String.fromString(input).compareTo(UTF8String.fromString(actual)) <= + 0, debugMsg) + } } - // scalastyle:off nonascii + // scalastyle:on nonascii }