Skip to content

Commit

Permalink
Parallel calls
Browse files Browse the repository at this point in the history
  • Loading branch information
sumeet-db committed Jul 29, 2024
1 parent ec3f6be commit aeaecf8
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@ case class DeltaStatsColumnSpec(
}

object StatisticsCollection extends DeltaCommand {

val ASCII_MAX_CHARACTER = '\u007F'

val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT))

/**
* The SQL grammar already includes a `multipartIdentifierList` rule for parsing a string into a
* list of multi-part identifiers. We just expose it here, with a custom parser and AstBuilder.
Expand Down Expand Up @@ -769,28 +774,69 @@ object StatisticsCollection extends DeltaCommand {
txn.commit(newAddFiles, ComputeStats(predicates))
}

def truncateMinStringAgg(prefixLen: Int)(input: String): String = {
if (input == null || input.length <= prefixLen) {
return input
}
if (prefixLen == 0) {
return null
}
if (Character.isHighSurrogate(input.charAt(prefixLen - 1)) &&
Character.isLowSurrogate(input.charAt(prefixLen))) {
// If the character at prefixLen - 1 is a high surrogate and the next character is a low
// surrogate, we need to include the next character in the prefix to ensure that we don't
// truncate the string in the middle of a surrogate pair.
input.take(prefixLen + 1)
} else {
input.take(prefixLen)
}
}

/**
* Helper method to truncate the input string `x` to the given `prefixLen` length, while also
* appending the unicode max character to the end of the truncated string. This ensures that any
* value in this column is less than or equal to the max.
* Note: Input string `x` must be properly encoded in UTF-8.
* Helper method to truncate the input string `input` to the given `prefixLen` length, while also
* ensuring the any value in this column is less than or equal to the truncated max in UTF-8
* encoding.
*/
def truncateMaxStringAgg(prefixLen: Int)(x: String): String = {
if (x == null || x.length <= prefixLen) {
x
} else {
// Grab the prefix. We want to append `\ufffd` as a tie-breaker, but that is only safe
// if the character we truncated was smaller. Keep extending the prefix until that
// condition holds, or we run off the end of the string.
// scalastyle:off nonascii
val tieBreaker = '\ufffd'
var ans = x.take(prefixLen) + x.substring(prefixLen).takeWhile(_ >= tieBreaker)
// Append a tie-breaker only if we truncated any characters from input string `x`.
if (ans.length < x.length) {
ans = ans + tieBreaker
def truncateMaxStringAgg(prefixLen: Int)(originalMax: String): String = {
if (originalMax == null || originalMax.length <= prefixLen) {
return originalMax
}
if (prefixLen == 0) {
return null
}

// scalastyle:off nonascii
// Grab the prefix. We want to append max Unicode code point `\uDBFF\uDFFF` as a tie-breaker,
// but that is only safe if the character we truncated was smaller in UTF-8 encoded binary
// comparison. Keep extending the prefix until that condition holds, or we run off the end of
// the string.
// We also try to use the ASCII max character `\u007F` as a tie-breaker if possible.
val maxLen = getExpansionLimit(prefixLen)
// Start with a valid prefix
var currLen = truncateMinStringAgg(prefixLen)(originalMax).length
while (currLen <= maxLen) {
if (currLen >= originalMax.length) {
// Return originalMax if we have reached the end of the string
return originalMax
} else if (currLen + 1 < originalMax.length &&
originalMax.substring(currLen, currLen + 2) == UTF8_MAX_CHARACTER) {
// Skip the UTF-8 max character. It occupies two characters in a Scala string.
currLen += 2
} else if (originalMax.charAt(currLen) < ASCII_MAX_CHARACTER) {
return originalMax.take(currLen) + ASCII_MAX_CHARACTER
} else {
return originalMax.take(currLen) + UTF8_MAX_CHARACTER
}
ans
// scalastyle:off nonascii
}

// Return null when the input string is too long to truncate.
null
// scalastyle:off nonascii
}

/**
* Calculates the upper character limit when constructing a maximum is not possible with only
* prefixLen chars.
*/
private def getExpansionLimit(prefixLen: Int): Int = 2 * prefixLen
}
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ abstract class StatsCollector(
// the max, check the helper function for more details.
StatisticsCollection.truncateMaxStringAgg(stringTruncateLength.get)(rawString)
} else {
rawString.substring(0, stringTruncateLength.get)
StatisticsCollection.truncateMinStringAgg(stringTruncateLength.get)(rawString)
}
} else {
rawString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.time.LocalDateTime
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.actions.Protocol
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.stats.StatisticsCollection.{ASCII_MAX_CHARACTER, UTF8_MAX_CHARACTER}
import org.apache.spark.sql.delta.test.{DeltaExceptionTestUtils, DeltaSQLCommandTest, DeltaSQLTestUtils, TestsStatistics}
import org.apache.spark.sql.delta.test.DeltaTestImplicits._
import org.apache.spark.sql.delta.util.JsonUtils
Expand All @@ -37,6 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSche
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

class StatsCollectionSuite
extends QueryTest
Expand Down Expand Up @@ -359,29 +361,71 @@ class StatsCollectionSuite
}
}

test("Truncate min string") {
// scalastyle:off nonascii
val inputToExpected = Seq(
(s"abcd", s"abc", 3),
(s"abcdef", s"abcdef", 6),
(s"abcde�", s"abcde�", 6),
(s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER",
s"$UTF8_MAX_CHARACTER",
1),
(s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER", s"$UTF8_MAX_CHARACTER", 1),
(s"abcd", null, 0)
)

inputToExpected.foreach {
case (input, expected, prefixLen) =>
val actual = StatisticsCollection.truncateMinStringAgg(prefixLen)(input)
val debugMsg = s"input:$input, actual:$actual, expected:$expected"
assert(actual == expected, debugMsg)
if (actual != null) {
assert(input.startsWith(actual), debugMsg)
}
}
// scalastyle:on nonascii
}

test("Truncate max string") {
// scalastyle:off nonascii
val prefixLen = 6
// � is the max unicode character with value \ufffd
val inputToExpected = Seq(
(s"abcd", s"abcd"),
(s"abcdef", s"abcdef"),
(s"abcde�", s"abcde�"),
(s"abcd�abcd", s"abcd�a�"),
(s"�abcd", s"�abcd"),
(s"abcdef�", s"abcdef�"),
(s"abcdef-abcdef�", s"abcdef�"),
(s"abcdef�abcdef", s"abcdef��"),
(s"abcdef��abcdef", s"abcdef���"),
(s"abcdef�abcdef�abcdef�abcdef", s"abcdef��"),
(s"漢字仮名한글தமி", s"漢字仮名한글�"),
(s"漢字仮名한글��", s"漢字仮名한글��"),
(s"漢字仮名한글", s"漢字仮名한글")
(s"abcd", null, 0),
(s"a${UTF8_MAX_CHARACTER}d", s"a$UTF8_MAX_CHARACTER$ASCII_MAX_CHARACTER", 2),
(s"abcd", s"abcd", 6),
(s"abcdef", s"abcdef", 6),
(s"abcde�", s"abcde�", 6),
(s"abcd�abcd", s"abcd�a$ASCII_MAX_CHARACTER", 6),
(s"�abcd", s"�abcd", 6),
(s"abcdef�", s"abcdef$UTF8_MAX_CHARACTER", 6),
(s"abcdef��", s"abcdef$UTF8_MAX_CHARACTER", 6),
(s"abcdef-abcdef�", s"abcdef$ASCII_MAX_CHARACTER", 6),
(s"abcdef�abcdef", s"abcdef$UTF8_MAX_CHARACTER", 6),
(s"abcde�abcdef�abcdef�abcdef", s"abcde�$ASCII_MAX_CHARACTER", 6),
(s"漢字仮名한글தமி", s"漢字仮名한글$UTF8_MAX_CHARACTER", 6),
(s"漢字仮名한글��", s"漢字仮名한글$UTF8_MAX_CHARACTER", 6),
(s"漢字仮名한글", s"漢字仮名한글", 6),
(s"abcdef🚀", s"abcdef$UTF8_MAX_CHARACTER", 6),
(s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER", null, 1),
(s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER",
s"$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER$UTF8_MAX_CHARACTER",
4),
(s"����", s"��$UTF8_MAX_CHARACTER", 2),
(s"���", s"$UTF8_MAX_CHARACTER", 1),
("abcdefghijklm💞😉💕\n🥀🌹💐🌺🌷🌼🌻🌷🥀",
s"abcdefghijklm💞😉💕\n🥀🌹💐🌺🌷🌼$UTF8_MAX_CHARACTER",
32)
)

inputToExpected.foreach {
case (input, expected) =>
case (input, expected, prefixLen) =>
val actual = StatisticsCollection.truncateMaxStringAgg(prefixLen)(input)
assert(actual == expected, s"input:$input, actual:$actual, expected:$expected")
// `Actual` should be higher or equal than `input` in UTF-8 encoded binary order.
val debugMsg = s"input:$input, actual:$actual, expected:$expected"
assert(actual == expected, debugMsg)
if (actual != null) {
assert(UTF8String.fromString(input).compareTo(UTF8String.fromString(actual)) <=
0, debugMsg)
}
}
// scalastyle:off nonascii
}
Expand Down

0 comments on commit aeaecf8

Please # to comment.