From 5292cd55cef74172e7ffbe9308f385bdf39983fb Mon Sep 17 00:00:00 2001 From: NVnavkumar <97137715+NVnavkumar@users.noreply.github.com> Date: Fri, 4 Mar 2022 12:20:05 -0800 Subject: [PATCH] Support for hexadecimal digits in regular expressions on the GPU (#4869) * Enable limited support for hexadecimal characters in regular expressions on the GPU Signed-off-by: Navin Kumar * Update references to issues regarded hex and octal in character classes, and update compatibility docs Signed-off-by: Navin Kumar * Correct bug in exception copy Signed-off-by: Navin Kumar --- docs/compatibility.md | 3 +- .../com/nvidia/spark/rapids/RegexParser.scala | 36 +++++++++++++++---- .../rapids/RegularExpressionParserSuite.scala | 5 +++ .../RegularExpressionTranspilerSuite.scala | 35 +++++++++++++----- 4 files changed, 62 insertions(+), 17 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index 586638495fe..600e2e25a63 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -550,7 +550,8 @@ Here are some examples of regular expression patterns that are not supported on - Regular expressions containing null characters (unless the pattern is a simple literal string) - Octal digits in the range `\0200` to `\0377` - Character classes with octal digits, such as `[\02]` or `[\024]` -- Hex digits +- Character classes with hex digits, such as `[\x02]` or `[\x24]` +- Hex digits in the range `\x80` to `Character.MAX_CODE_POINT` - `regexp_replace` does not support back-references Work is ongoing to increase the range of regular expressions that can run on the GPU. diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 9c05cbf9c67..bae02e5af10 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -302,13 +302,18 @@ class RegexParser(pattern: String) { // \x{h...h} The character with hexadecimal value 0xh...h // (Character.MIN_CODE_POINT <= 0xh...h <= Character.MAX_CODE_POINT) + val varHex = pattern.charAt(pos) == '{' + if (varHex) { + consumeExpected('{') + } val start = pos while (!eof() && isHexDigit(pattern.charAt(pos))) { pos += 1 } val hexDigit = pattern.substring(start, pos) - - if (hexDigit.length < 2) { + if (varHex) { + consumeExpected('}') + } else if (hexDigit.length != 2) { throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit") } @@ -554,15 +559,20 @@ class CudfRegexTranspiler(mode: RegexMode) { digits } if (Integer.parseInt(octal, 8) >= 128) { + // see https://github.com/NVIDIA/spark-rapids/issues/4746 throw new RegexUnsupportedException( "cuDF does not support octal digits 0o177 < n <= 0o377") } RegexOctalChar(octal) - case RegexHexDigit(_) => - // see https://github.com/NVIDIA/spark-rapids/issues/4486 - throw new RegexUnsupportedException( - s"cuDF does not support hex digits consistently with Spark") + case RegexHexDigit(digits) => + val codePoint = Integer.parseInt(digits, 16) + if (codePoint >= 128) { + // see https://github.com/NVIDIA/spark-rapids/issues/4866 + throw new RegexUnsupportedException( + "cuDF does not support hex digits > 0x7F") + } + RegexHexDigit(String.format("%02x", Int.box(codePoint))) case RegexEscaped(ch) => ch match { case 'D' => @@ -609,10 +619,16 @@ class CudfRegexTranspiler(mode: RegexMode) { // - "[a-b[c-d]]" is supported by Java but not cuDF throw new RegexUnsupportedException("nested character classes are not supported") case RegexEscaped(ch) if ch == '0' => + // see https://github.com/NVIDIA/spark-rapids/issues/4862 // examples // - "[\02] should match the character with code point 2" throw new RegexUnsupportedException( "cuDF does not support octal digits in character classes") + case RegexEscaped(ch) if ch == 'x' => + // examples + // - "[\x02] should match the character with code point 2" + throw new RegexUnsupportedException( + "cuDF does not support hex digits in character classes") case _ => } val components: Seq[RegexCharacterClassComponent] = characters @@ -828,7 +844,13 @@ sealed trait RegexCharacterClassComponent extends RegexAST sealed case class RegexHexDigit(a: String) extends RegexCharacterClassComponent { override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = s"\\x$a" + override def toRegexString: String = { + if (a.length == 2) { + s"\\x$a" + } else { + s"\\x{$a}" + } + } } sealed case class RegexOctalChar(a: String) extends RegexCharacterClassComponent { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 06c642ba561..df092485d94 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -124,6 +124,11 @@ class RegularExpressionParserSuite extends FunSuite { RegexSequence(ListBuffer(RegexHexDigit("FF")))) } + test("variable length hex digit") { + assert(parse(raw"\x{ABC}") === + RegexSequence(ListBuffer(RegexHexDigit("ABC")))) + } + test("octal digit") { val digits = Seq("0", "01", "076", "077", "0123", "0177", "0377") for (digit <- digits) { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 86bbfebb339..fa72edff2fa 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -144,13 +144,23 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("cuDF does not support octal digits 0o177 < n <= 0o377") { + // see https://github.com/NVIDIA/spark-rapids/issues/4746 val patterns = Seq(raw"\0200", raw"\0377") patterns.foreach(pattern => assertUnsupported(pattern, RegexFindMode, "cuDF does not support octal digits 0o177 < n <= 0o377")) } + test("cuDF does not support hex digits > 0x7F") { + // see https://github.com/NVIDIA/spark-rapids/issues/4866 + val patterns = Seq(raw"\x80", raw"\xff", raw"\xFF", raw"\x{ABC}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, + "cuDF does not support hex digits > 0x7F")) + } + test("cuDF does not support octal digits in character classes") { + // see https://github.com/NVIDIA/spark-rapids/issues/4862 val patterns = Seq(raw"[\02]", raw"[\012]", raw"[\0177]") patterns.foreach(pattern => assertUnsupported(pattern, RegexFindMode, @@ -159,19 +169,26 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { ) } - test("cuDF does not support hex digits consistently with Spark") { - // see https://github.com/NVIDIA/spark-rapids/issues/4486 - val patterns = Seq(raw"\xA9", raw"\x00A9", raw"\x10FFFF") + test("cuDF does not support hex digits in character classes") { + // see https://github.com/NVIDIA/spark-rapids/issues/4865 + val patterns = Seq(raw"[\x02]", raw"[\x2c]", raw"[\x7f]") patterns.foreach(pattern => assertUnsupported(pattern, RegexFindMode, - "cuDF does not support hex digits consistently with Spark")) + "cuDF does not support hex digits in character classes" + ) + ) } test("octal digits < 0o177 - find") { - // val patterns = Seq(raw"\07", raw"\077", raw"\0177", raw"\0377") - val patterns = Seq(raw"\07", raw"\077", raw"\0177") + val patterns = Seq(raw"\07", raw"\077", raw"\0177", raw"\01772") + assertCpuGpuMatchesRegexpFind(patterns, Seq("", "\u0007", "a\u0007b", + "\u0007\u003f\u007f", "\u007f", "\u007f2")) + } + + test("hex digits < 0x7F - find") { + val patterns = Seq(raw"\x07", raw"\x3f", raw"\x7F", raw"\x7f", raw"\x{7}", raw"\x{0007f}") assertCpuGpuMatchesRegexpFind(patterns, Seq("", "\u0007", "a\u0007b", - "\u0007\u003f\u007f", "\u007f")) + "\u0007\u003f\u007f", "\u007f", "\u007f2")) } test("string anchors - find") { @@ -772,8 +789,8 @@ class FuzzRegExp(suggestedChars: String, skipKnownIssues: Boolean = true) { } else { baseGenerators ++ Seq( () => escapedChar, // https://github.com/NVIDIA/spark-rapids/issues/4505 - () => hexDigit, // https://github.com/NVIDIA/spark-rapids/issues/4486 - () => octalDigit) // https://github.com/NVIDIA/spark-rapids/issues/4409 + () => hexDigit, // https://github.com/NVIDIA/spark-rapids/issues/4865 + () => octalDigit) // https://github.com/NVIDIA/spark-rapids/issues/4862 } generators(rr.nextInt(generators.length))() }