diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index b5bbd4a7d31..2c41eb2dd15 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -858,15 +858,67 @@ class CudfRegexTranspiler(mode: RegexMode) { throw new RegexUnsupportedException( "regex_replace and regex_split on GPU do not support repetition with {0}") - case (RegexGroup(_, term), SimpleQuantifier(ch)) + case (RegexGroup(capture, term), SimpleQuantifier(ch)) if "+*".contains(ch) && !isSupportedRepetitionBase(term) => - throw new RegexUnsupportedException(nothingToRepeat) - case (RegexGroup(_, term), QuantifierVariableLength(_, None)) + (term, ch) match { + // \Z is not supported in groups + case (RegexEscaped('A'), '+') | + (RegexSequence(ListBuffer(RegexEscaped('A'))), '+') => + // (\A)+ can be transpiled to (\A) (dropping the repetition) + // we use rewrite(...) here to handle logic regarding modes + // (\A is not supported in RegexSplitMode) + RegexGroup(capture, rewrite(term, previous)) + // NOTE: (\A)* can be transpiled to (\A)? + // however, (\A)? is not supported in libcudf yet + case _ => + throw new RegexUnsupportedException(nothingToRepeat) + } + case (RegexGroup(capture, term), QuantifierVariableLength(n, _)) if !isSupportedRepetitionBase(term) => - // specifically this variable length repetition: \A{2,} - throw new RegexUnsupportedException(nothingToRepeat) + term match { + // \Z is not supported in groups + case RegexEscaped('A') | RegexSequence(ListBuffer(RegexEscaped('A'))) if n > 0 => + // (\A){1,} can be transpiled to (\A) (dropping the repetition) + // we use rewrite(...) here to handle logic regarding modes + // (\A is not supported in RegexSplitMode) + RegexGroup(capture, rewrite(term, previous)) + // NOTE: (\A)* can be transpiled to (\A)? + // however, (\A)? 
is not supported in libcudf yet + case _ => + throw new RegexUnsupportedException(nothingToRepeat) + } + case (RegexGroup(capture, term), QuantifierFixedLength(n)) + if !isSupportedRepetitionBase(term) => + term match { + // \Z is not supported in groups + case RegexEscaped('A') | RegexSequence(ListBuffer(RegexEscaped('A'))) if n > 0 => + // (\A){2} can be transpiled to (\A) (dropping the repetition) + // we use rewrite(...) here to handle logic regarding modes + // (\A is not supported in RegexSplitMode) + RegexGroup(capture, rewrite(term, previous)) + // NOTE: (\A)* can be transpiled to (\A)? + // however, (\A)? is not supported in libcudf yet + case _ => + throw new RegexUnsupportedException(nothingToRepeat) + } case (RegexGroup(_, _), SimpleQuantifier(ch)) if ch == '?' => RegexRepetition(rewrite(base, None), quantifier) + case (RegexEscaped(ch), SimpleQuantifier('+')) if "AZ".contains(ch) => + // \A+ can be transpiled to \A (dropping the repetition) + // \Z+ can be transpiled to \Z (dropping the repetition) + // we use rewrite(...) here to handle logic regarding modes + // (\A and \Z are not supported in RegexSplitMode) + rewrite(base, previous) + // NOTE: \A* can be transpiled to \A? + // however, \A? 
is not supported in libcudf yet + case (RegexEscaped(ch), QuantifierFixedLength(n)) if n > 0 && "AZ".contains(ch) => + // \A{2} can be transpiled to \A (dropping the repetition) + // \Z{2} can be transpiled to \Z (dropping the repetition) + rewrite(base, previous) + case (RegexEscaped(ch), QuantifierVariableLength(n,_)) if n > 0 && "AZ".contains(ch) => + // \A{1,5} can be transpiled to \A (dropping the repetition) + // \Z{1,} can be transpiled to \Z (dropping the repetition) + rewrite(base, previous) case _ if isSupportedRepetitionBase(base) => RegexRepetition(rewrite(base, None), quantifier) case _ => @@ -913,7 +965,7 @@ class CudfRegexTranspiler(mode: RegexMode) { case RegexChoice(l, r) => isBeginOrEndLineAnchor(l) && isBeginOrEndLineAnchor(r) case RegexRepetition(term, _) => isBeginOrEndLineAnchor(term) case RegexChar(ch) => ch == '^' || ch == '$' - case RegexEscaped('z') => true // \z gets translated to $ + case RegexEscaped(ch) if "zZ".contains(ch) => true // \z gets translated to $ (and \Z to a pattern ending in $) case _ => false } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 12803933679..0085625cc52 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -191,13 +191,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("string anchors - find") { - val patterns = Seq("\\Atest", "test\\z") + val patterns = Seq("\\Atest", "\\A+test", "\\A{1}test", "\\A{1,}test", + "(\\A)+test", "(\\A){1}test", "(\\A){1,}test", "test\\z") assertCpuGpuMatchesRegexpFind(patterns, Seq("", "test", "atest", "testa", "\ntest", "test\n", "\ntest\n")) } test("string anchor \\A will fall back to CPU in some repetitions") { - val patterns = Seq(raw"(\A)+", raw"(\A)*", raw"(\A){2,}") + val patterns = Seq(raw"(\A)*a", 
raw"(\A){0,}a", raw"(\A){0}a") patterns.foreach(pattern => assertUnsupported(pattern, RegexFindMode, "nothing to repeat") ) @@ -205,10 +206,25 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("string anchor \\Z fall back to CPU - replace or split") { for (mode <- Seq(RegexReplaceMode, RegexSplitMode)) { - assertUnsupported("\\Z", mode, "string anchor \\Z is not supported in split or replace mode") + assertUnsupported("a\\Z", mode, "string anchor \\Z is not supported in split or replace mode") } } + test("string anchor \\Z fall back to CPU in groups") { + val patterns = Seq(raw"(\Z)", raw"(\Z)+") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, + "sequences that only contain '^' or '$' are not supported") + ) + } + + test("string anchor \\Z fall back to CPU in some repetitions") { + val patterns = Seq(raw"a(\Z)*", raw"a(\Z){2,}") + patterns.foreach(pattern => + assertUnsupported(pattern, RegexFindMode, "nothing to repeat") + ) + } + test("string anchors - replace") { val patterns = Seq("\\Atest") assertCpuGpuMatchesRegexpReplace(patterns, Seq("", "test", "atest", "testa", @@ -236,7 +252,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("string anchor \\Z - find") { val patterns = Seq("\\Z\r", "a\\Z", "\r\\Z", "\f\\Z", "\\Z\f", "\u0085\\Z", "\u2028\\Z", - "\u2029\\Z", "\n\\Z", "\r\n\\Z", "[\r\n]?\\Z", "\\00*[D$3]\\Z", "a\\Zb") + "\u2029\\Z", "\n\\Z", "\r\n\\Z", "[\r\n]?\\Z", "\\00*[D$3]\\Z", "a\\Zb", "a\\Z+") val inputs = Seq("a", "a\n", "a\r", "a\r\n", "a\u0085\n", "a\f", "\f", "\r", "\u0085", "\u2028", "\u2029", "\n", "\r\n", "\r\n\r", "\r\n\u0085", "\u0085\r", "\u2028\n", "\u2029\n", "\n\r", "\n\u0085", "\n\u2028", "\n\u2029", "2+|+??wD\n", "a\r\nb") @@ -298,6 +314,13 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { .replaceAll("\\\\z", "\\$")) } + test("transpile \\A repetitions") { + doTranspileTest("a\\A+", "a\\A") + doTranspileTest("a\\A{1,}", "a\\A") + 
doTranspileTest("a\\A{2}", "a\\A") + doTranspileTest("a(\\A)+", "a(\\A)") + } + test("transpile \\z") { doTranspileTest("abc\\z", "abc$") } @@ -316,6 +339,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { doTranspileTest("]\\Z\r", "]\r[\n\u0085\u2028\u2029]?$") doTranspileTest("^\\Z[^*A-ZA-Z]", "^[\n\r\u0085\u2028\u2029]$") doTranspileTest("^\\Z([^*A-ZA-Z])", "^([\n\r\u0085\u2028\u2029])$") + doTranspileTest("a\\Z+", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") + doTranspileTest("a\\Z{1}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") + doTranspileTest("a\\Z{1,}", "a(?:[\n\r\u0085\u2028\u2029]|\r\n)?$") } test("compare CPU and GPU: character range including unescaped + and -") {