diff --git a/docs/compatibility.md b/docs/compatibility.md index 17e642f344c..df026d05a6b 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -257,14 +257,32 @@ The plugin supports reading `uncompressed`, `snappy` and `gzip` Parquet files an fall back to the CPU when reading an unsupported compression format, and will error out in that case. -## Regular Expressions -The RAPIDS Accelerator for Apache Spark currently supports string literal matches, not wildcard -matches. +## LIKE If a null char '\0' is in a string that is being matched by a regular expression, `LIKE` sees it as the end of the string. This will be fixed in a future release. The issue is [here](https://github.com/NVIDIA/spark-rapids/issues/119). +## Regular Expressions + +### regexp_replace + +The RAPIDS Accelerator for Apache Spark currently supports string literal matches, not wildcard +matches for the `regexp_replace` function and will fall back to CPU if a regular expression pattern +is provided. + +### RLike + +The GPU implementation of `RLike` has the following known issues where behavior is not consistent with Apache Spark and +this expression is disabled by default. It can be enabled setting `spark.rapids.sql.expression.RLike=true`. + +- `.` matches `\r` on the GPU but not on the CPU ([cuDF issue #9619](https://github.com/rapidsai/cudf/issues/9619)) +- `$` does not match the end of string if the string ends with a line-terminator + ([cuDF issue #9620](https://github.com/rapidsai/cudf/issues/9620)) + +`RLike` will fall back to CPU if any regular expressions are detected that are not supported on the GPU +or would produce different results on the GPU. + ## Timestamps Spark stores timestamps internally relative to the JVM time zone. Converting an arbitrary timestamp @@ -569,60 +587,6 @@ distribution. Because the results are not bit-for-bit identical with the Apache `approximate_percentile`, this feature is disabled by default and can be enabled by setting `spark.rapids.sql.expression.ApproximatePercentile=true`. -## RLike - -The GPU implementation of RLike has a number of known issues where behavior is not consistent with Apache Spark and -this expression is disabled by default. It can be enabled setting `spark.rapids.sql.expression.RLike=true`. - -A summary of known issues is shown below but this is not intended to be a comprehensive list. We recommend that you -do your own testing to verify whether the GPU implementation of `RLike` is suitable for your use case. - -We plan on improving the RLike functionality over time to make it more compatible with Spark so this feature should -be used at your own risk with the expectation that the behavior will change in future releases. - -### Multi-line handling - -The GPU implementation of RLike supports `^` and `$` to represent the start and end of lines within a string but -Spark uses `^` and `$` to refer to the start and end of the entire string (equivalent to `\A` and `\Z`). - -| Pattern | Input | Spark on CPU | Spark on GPU | -|---------|--------|--------------|--------------| -| `^A` | `A\nB` | Match | Match | -| `A$` | `A\nB` | No Match | Match | -| `^B` | `A\nB` | No Match | Match | -| `B$` | `A\nB` | Match | Match | - -As a workaround, `\A` and `\Z` can be used instead of `^` and `$`. - -### Null support - -The GPU implementation of RLike supports null characters in the input but does not support null characters in -the regular expression and will fall back to the CPU in this case. - -### Qualifiers with nothing to repeat - -Spark supports qualifiers in cases where there is nothing to repeat. For example, Spark supports `a*+` and this -will match all inputs. The GPU implementation of RLike does not support this syntax and will throw an exception with -the message `nothing to repeat at position 0`. - -### Stricter escaping requirements - -The GPU implementation of RLike has stricter requirements around escaping special characters in some cases. - -| Pattern | Input | Spark on CPU | Spark on GPU | -|-----------|--------|--------------|--------------| -| `a[-+]` | `a-` | Match | No Match | -| `a[\-\+]` | `a-` | Match | Match | - -### Empty groups - -The GPU implementation of RLike does not support empty groups correctly. - -| Pattern | Input | Spark on CPU | Spark on GPU | -|-----------|--------|--------------|--------------| -| `z()?` | `a` | No Match | Match | -| `z()*` | `a` | No Match | Match | - ## Conditionals and operations with side effects (ANSI mode) In Apache Spark condition operations like `if`, `coalesce`, and `case/when` lazily evaluate diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 792d870580b..cfd0010d2ef 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -492,7 +492,7 @@ def test_rlike_embedded_null(): conf={'spark.rapids.sql.expression.RLike': 'true'}) @allow_non_gpu('ProjectExec', 'RLike') -def test_rlike_null_pattern(): +def test_rlike_fallback_null_pattern(): gen = mk_str_gen('[abcd]{1,3}') assert_gpu_fallback_collect( lambda spark: unary_op_df(spark, gen).selectExpr( @@ -500,6 +500,15 @@ def test_rlike_null_pattern(): 'RLike', conf={'spark.rapids.sql.expression.RLike': 'true'}) +@allow_non_gpu('ProjectExec', 'RLike') +def test_rlike_fallback_empty_group(): + gen = mk_str_gen('[abcd]{1,3}') + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a()?"'), + 'RLike', + conf={'spark.rapids.sql.expression.RLike': 'true'}) + def test_rlike_escape(): gen = mk_str_gen('[ab]{0,2}[\\-\\+]{0,2}') assert_gpu_and_cpu_are_equal_collect( @@ -507,7 +516,6 @@ def test_rlike_escape(): 'a rlike "a[\\\\-]"'), conf={'spark.rapids.sql.expression.RLike': 'true'}) -@pytest.mark.xfail(reason='cuDF supports multiline by default but Spark does not - https://github.com/rapidsai/cudf/issues/9439') def test_rlike_multi_line(): gen = mk_str_gen('[abc]\n[def]') assert_gpu_and_cpu_are_equal_collect( @@ -518,18 +526,20 @@ def test_rlike_multi_line(): 'a rlike "e$"'), conf={'spark.rapids.sql.expression.RLike': 'true'}) -@pytest.mark.xfail(reason='cuDF has stricter requirements around escaping - https://github.com/rapidsai/cudf/issues/9434') def test_rlike_missing_escape(): gen = mk_str_gen('a[\\-\\+]') assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "a[-]"'), + 'a rlike "a[-]"', + 'a rlike "a[+-]"', + 'a rlike "a[a-b-]"'), conf={'spark.rapids.sql.expression.RLike': 'true'}) -@pytest.mark.xfail(reason='cuDF does not support qualifier with nothing to repeat - https://github.com/rapidsai/cudf/issues/9434') -def test_rlike_nothing_to_repeat(): +@allow_non_gpu('ProjectExec', 'RLike') +def test_rlike_fallback_possessive_quantifier(): gen = mk_str_gen('(\u20ac|\\w){0,3}a[|b*.$\r\n]{0,2}c\\w{0,3}') - assert_gpu_and_cpu_are_equal_collect( + assert_gpu_fallback_collect( lambda spark: unary_op_df(spark, gen).selectExpr( 'a rlike "a*+"'), + 'RLike', conf={'spark.rapids.sql.expression.RLike': 'true'}) \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala new file mode 100644 index 00000000000..bfc12b865fc --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -0,0 +1,645 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import java.sql.SQLException + +import scala.collection.mutable.ListBuffer + +/** + * Regular expression parser based on a Pratt Parser design. + * + * The goal of this parser is to build a minimal AST that allows us + * to validate that we can support the expression on the GPU. The goal + * is not to parse with the level of detail that would be required if + * we were building an evaluation engine. For example, operator precedence is + * largely ignored but could be added if we need it later. + * + * The Java and cuDF regular expression documentation has been used as a reference: + * + * Java regex: https://docs.oracle.com/javase/7/docs/api/java/util/regex/Pattern.html + * cuDF regex: https://docs.rapids.ai/api/libcudf/stable/md_regex.html + * + * The following blog posts provide some background on Pratt Parsers and parsing regex. + * + * - https://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/ + * - https://matt.might.net/articles/parsing-regex-with-recursive-descent/ + */ +class RegexParser(pattern: String) { + + /** index of current position within the string being parsed */ + private var pos = 0 + + def parse(): RegexAST = { + val ast = parseInternal() + if (!eof()) { + throw new RegexUnsupportedException("failed to parse full regex") + } + ast + } + + private def parseInternal(): RegexAST = { + val term = parseTerm(() => peek().contains('|')) + if (!eof() && peek().contains('|')) { + consumeExpected('|') + RegexChoice(term, parseInternal()) + } else { + term + } + } + + private def parseTerm(until: () => Boolean): RegexAST = { + val sequence = RegexSequence(new ListBuffer()) + while (!eof() && !until()) { + parseFactor() match { + case RegexSequence(parts) => + sequence.parts ++= parts + case other => + sequence.parts += other + } + } + sequence + } + + private def isValidQuantifierAhead(): Boolean = { + if (peek().contains('{')) { + val bookmark = pos + consumeExpected('{') + val q = parseQuantifierOrLiteralBrace() + pos = bookmark + q match { + case _: QuantifierFixedLength | _: QuantifierVariableLength => true + case _ => false + } + } else { + false + } + } + + private def parseFactor(): RegexAST = { + var base = parseBase() + while (!eof() && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?') + || isValidQuantifierAhead())) { + + val quantifier = if (peek().contains('{')) { + consumeExpected('{') + parseQuantifierOrLiteralBrace().asInstanceOf[RegexQuantifier] + } else { + SimpleQuantifier(consume()) + } + base = RegexRepetition(base, quantifier) + } + base + } + + private def parseBase(): RegexAST = { + consume() match { + case '(' => + parseGroup() + case '[' => + parseCharacterClass() + case '\\' => + parseEscapedCharacter() + case '\u0000' => + throw new RegexUnsupportedException( + "cuDF does not support null characters in regular expressions", Some(pos)) + case other => + RegexChar(other) + } + } + + private def parseGroup(): RegexAST = { + val term = parseTerm(() => peek().contains(')')) + consumeExpected(')') + RegexGroup(term) + } + + private def parseCharacterClass(): RegexCharacterClass = { + val start = pos + val characterClass = RegexCharacterClass(negated = false, characters = ListBuffer()) + // loop until the end of the character class or EOF + var characterClassComplete = false + while (!eof() && !characterClassComplete) { + val ch = consume() + ch match { + case '[' => + // treat as a literal character and add to the character class + characterClass.append(ch) + case ']' => + characterClassComplete = true + case '^' if pos == start + 1 => + // Negates the character class, causing it to match a single character not listed in + // the character class. Only valid immediately after the opening '[' + characterClass.negated = true + case '\n' | '\r' | '\t' | '\b' | '\f' | '\u0007' => + // treat as a literal character and add to the character class + characterClass.append(ch) + case '\\' => + peek() match { + case None => + throw new RegexUnsupportedException( + s"unexpected EOF while parsing escaped character", Some(pos)) + case Some(ch) => + ch match { + case '\\' | '^' | '-' | ']' | '+' => + // escaped metacharacter within character class + characterClass.appendEscaped(consumeExpected(ch)) + } + } + case '\u0000' => + throw new RegexUnsupportedException( + "cuDF does not support null characters in regular expressions", Some(pos)) + case _ => + // check for range + val start = ch + peek() match { + case Some('-') => + consumeExpected('-') + peek() match { + case Some(']') => + // '-' at end of class e.g. "[abc-]" + characterClass.append(ch) + characterClass.append('-') + case Some(end) => + skip() + characterClass.appendRange(start, end) + } + case _ => + // treat as supported literal character + characterClass.append(ch) + } + } + } + if (!characterClassComplete) { + throw new RegexUnsupportedException( + s"unexpected EOF while parsing character class", Some(pos)) + } + characterClass + } + + + /** + * Parse a quantifier in one of the following formats: + * + * {n} + * {n,} + * {n,m} (only valid if m >= n) + */ + private def parseQuantifierOrLiteralBrace(): RegexAST = { + + // assumes that '{' has already been consumed + val start = pos + + def treatAsLiteralBrace() = { + // this was not a quantifier, just a literal '{' + pos = start + 1 + RegexChar('{') + } + + consumeInt match { + case Some(minLength) => + peek() match { + case Some(',') => + consumeExpected(',') + val max = consumeInt() + if (peek().contains('}')) { + consumeExpected('}') + max match { + case None => + QuantifierVariableLength(minLength, None) + case Some(m) => + if (m >= minLength) { + QuantifierVariableLength(minLength, max) + } else { + treatAsLiteralBrace() + } + } + } else { + treatAsLiteralBrace() + } + case Some('}') => + consumeExpected('}') + QuantifierFixedLength(minLength) + case _ => + treatAsLiteralBrace() + } + case None => + treatAsLiteralBrace() + } + } + + private def parseEscapedCharacter(): RegexAST = { + peek() match { + case None => + throw new RegexUnsupportedException("escape at end of string", Some(pos)) + case Some(ch) => + ch match { + case 'A' | 'Z' => + // BOL / EOL anchors + consumeExpected(ch) + RegexEscaped(ch) + case 's' | 'S' | 'd' | 'D' | 'w' | 'W' => + // meta sequences + consumeExpected(ch) + RegexEscaped(ch) + case 'B' | 'b' => + // word boundaries + consumeExpected(ch) + RegexEscaped(ch) + case '[' | '\\' | '^' | '$' | '.' | '⎮' | '?' | '*' | '+' | '(' | ')' | '{' | '}' => + // escaped metacharacter + consumeExpected(ch) + RegexEscaped(ch) + case 'x' => + consumeExpected(ch) + parseHexDigit + case _ if Character.isDigit(ch) => + parseOctalDigit + case other => + throw new RegexUnsupportedException( + s"invalid or unsupported escape character '$other'", Some(pos - 1)) + } + } + } + + private def isHexDigit(ch: Char): Boolean = ch.isDigit || + (ch >= 'a' && ch <= 'f') || + (ch >= 'A' && ch <= 'F') + + private def parseHexDigit: RegexHexDigit = { + // \xhh The character with hexadecimal value 0xhh + // \x{h...h} The character with hexadecimal value 0xh...h + // (Character.MIN_CODE_POINT <= 0xh...h <= Character.MAX_CODE_POINT) + + val start = pos + while (!eof() && isHexDigit(pattern.charAt(pos))) { + pos += 1 + } + val hexDigit = pattern.substring(start, pos) + + if (hexDigit.length < 2) { + throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit") + } + + val value = Integer.parseInt(hexDigit, 16) + if (value < Character.MIN_CODE_POINT || value > Character.MAX_CODE_POINT) { + throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit") + } + + RegexHexDigit(hexDigit) + } + + private def isOctalDigit(ch: Char): Boolean = ch >= '0' && ch <= '7' + + private def parseOctalDigit: RegexOctalChar = { + // \0n The character with octal value 0n (0 <= n <= 7) + // \0nn The character with octal value 0nn (0 <= n <= 7) + // \0mnn The character with octal value 0mnn (0 <= m <= 3, 0 <= n <= 7) + + def parseOctalDigits(n: Integer): RegexOctalChar = { + val octal = pattern.substring(pos, pos + n) + pos += n + RegexOctalChar(octal) + } + + if (!eof() && isOctalDigit(pattern.charAt(pos))) { + if (pos + 1 < pattern.length && isOctalDigit(pattern.charAt(pos + 1))) { + if (pos + 2 < pattern.length && isOctalDigit(pattern.charAt(pos + 2)) + && pattern.charAt(pos) <= '3') { + parseOctalDigits(3) + } else { + parseOctalDigits(2) + } + } else { + parseOctalDigits(1) + } + } else { + throw new RegexUnsupportedException( + "Invalid octal digit", Some(pos)) + } + } + + /** Determine if we are at the end of the input */ + private def eof(): Boolean = pos == pattern.length + + /** Advance the index by one */ + private def skip(): Unit = { + if (eof()) { + throw new RegexUnsupportedException("Unexpected EOF", Some(pos)) + } + pos += 1 + } + + /** Get the next character and advance the index by one */ + private def consume(): Char = { + if (eof()) { + throw new RegexUnsupportedException("Unexpected EOF", Some(pos)) + } else { + pos += 1 + pattern.charAt(pos - 1) + } + } + + /** Consume the next character if it is the one we expect */ + private def consumeExpected(expected: Char): Char = { + val consumed = consume() + if (consumed != expected) { + throw new RegexUnsupportedException( + s"Expected '$expected' but found '$consumed'", Some(pos-1)) + } + consumed + } + + /** Peek at the next character without consuming it */ + private def peek(): Option[Char] = { + if (eof()) { + None + } else { + Some(pattern.charAt(pos)) + } + } + + private def consumeInt(): Option[Int] = { + val start = pos + while (!eof() && peek().exists(_.isDigit)) { + skip() + } + if (start == pos) { + None + } else { + Some(pattern.substring(start, pos).toInt) + } + } + +} + +/** + * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception + * if this is not possible. + */ +class CudfRegexTranspiler { + + val nothingToRepeat = "nothing to repeat" + + def transpile(pattern: String): String = { + // parse the source regular expression + val regex = new RegexParser(pattern).parse() + // validate that the regex is supported by cuDF + validate(regex) + // write out to regex string, performing minor transformations + // such as adding additional escaping + regex.toRegexString + } + + private def validate(regex: RegexAST): Unit = { + regex match { + case RegexOctalChar(_) => + // cuDF produced different results compared to Spark in some cases + // example: "a\141|.$" + throw new RegexUnsupportedException( + s"cuDF does not support octal digits consistently with Spark") + case RegexEscaped(ch) if ch == 'b' || ch == 'B' => + // example: "a\Bb" + // this needs further analysis to determine why words boundaries behave + // differently between Java and cuDF + throw new RegexUnsupportedException("word boundaries are not supported") + case RegexSequence(parts) if parts.isEmpty => + // examples: "", "()", "a|", "|b" + throw new RegexUnsupportedException("empty sequence not supported") + case RegexRepetition(RegexEscaped(_), _) => + // example: "\B?" + throw new RegexUnsupportedException(nothingToRepeat) + case RegexRepetition(RegexChar(a), _) if "$^".contains(a) => + // example: "$*" + throw new RegexUnsupportedException(nothingToRepeat) + case RegexRepetition(RegexRepetition(_, _), _) => + // example: "a*+" + throw new RegexUnsupportedException(nothingToRepeat) + case RegexChoice(l, r) => + (l, r) match { + // check for empty left-hand side caused by ^ or $ or a repetition + case (RegexSequence(a), _) => + a.lastOption match { + case None => + // example: "|a" + throw new RegexUnsupportedException(nothingToRepeat) + case Some(RegexChar(ch)) if ch == '$' || ch == '^' => + // example: "^|a" + throw new RegexUnsupportedException(nothingToRepeat) + case Some(RegexRepetition(_, _)) => + // example: "a*|a" + throw new RegexUnsupportedException(nothingToRepeat) + case _ => + } + // check for empty right-hand side caused by ^ or $ + case (_, RegexSequence(b)) => + b.headOption match { + case None => + // example: "|b" + throw new RegexUnsupportedException(nothingToRepeat) + case Some(RegexChar(ch)) if ch == '$' || ch == '^' => + // example: "a|$" + throw new RegexUnsupportedException(nothingToRepeat) + case _ => + } + case (RegexRepetition(_, _), _) => + // example: "a*|a" + throw new RegexUnsupportedException(nothingToRepeat) + case _ => + } + + case RegexSequence(parts) => + if (isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) { + // examples: "a|", "|b" + throw new RegexUnsupportedException(nothingToRepeat) + } + if (isRegexChar(parts.head, '{')) { + // example: "{" + // cuDF would treat this as a quantifier even though in this + // context (being at the start of a sequence) it is not quantifying anything + // note that we could choose to escape this in the transpiler rather than + // falling back to CPU + throw new RegexUnsupportedException(nothingToRepeat) + } + case RegexCharacterClass(_, characters) => + characters.foreach { + case RegexChar(ch) if ch == '[' || ch == ']' => + // examples: + // - "[a[]" should match the literal characters "a" and "[" + // - "[a-b[c-d]]" is supported by Java but not cuDF + throw new RegexUnsupportedException("nested character classes are not supported") + case _ => + } + + case _ => + } + + // walk down the tree and validate children + regex.children().foreach(validate) + } + + private def isRegexChar(expr: RegexAST, value: Char): Boolean = expr match { + case RegexChar(ch) => ch == value + case _ => false + } +} + +sealed trait RegexAST { + def children(): Seq[RegexAST] + def toRegexString: String +} + +sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST { + override def children(): Seq[RegexAST] = parts + override def toRegexString: String = parts.map(_.toRegexString).mkString +} + +sealed case class RegexGroup(term: RegexAST) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(term) + override def toRegexString: String = s"(${term.toRegexString})" +} + +sealed case class RegexChoice(a: RegexAST, b: RegexAST) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(a, b) + override def toRegexString: String = s"${a.toRegexString}|${b.toRegexString}" +} + +sealed case class RegexRepetition(a: RegexAST, quantifier: RegexQuantifier) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(a) + override def toRegexString: String = s"${a.toRegexString}${quantifier.toRegexString}" +} + +sealed trait RegexQuantifier extends RegexAST + +sealed case class SimpleQuantifier(ch: Char) extends RegexQuantifier { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = ch.toString +} + +sealed case class QuantifierFixedLength(length: Int) + extends RegexQuantifier { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = { + s"{$length}" + } +} + +sealed case class QuantifierVariableLength(minLength: Int, maxLength: Option[Int]) + extends RegexQuantifier{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = { + maxLength match { + case Some(max) => + s"{$minLength,$max}" + case _ => + s"{$minLength,}" + } + } +} + +sealed trait RegexCharacterClassComponent extends RegexAST + +sealed case class RegexHexDigit(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\x$a" +} + +sealed case class RegexOctalChar(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\$a" +} + +sealed case class RegexChar(a: Char) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"$a" +} + +sealed case class RegexUnicodeChar(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\u$a" +} + +sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\$a" +} + +sealed case class RegexCharacterRange(start: Char, end: Char) + extends RegexCharacterClassComponent{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"$start-$end" +} + +sealed case class RegexCharacterClass( + var negated: Boolean, + var characters: ListBuffer[RegexCharacterClassComponent]) + extends RegexAST { + + override def children(): Seq[RegexAST] = characters + def append(ch: Char): Unit = { + characters += RegexChar(ch) + } + + def appendEscaped(ch: Char): Unit = { + characters += RegexEscaped(ch) + } + + def appendRange(start: Char, end: Char): Unit = { + characters += RegexCharacterRange(start, end) + } + + override def toRegexString: String = { + val builder = new StringBuilder("[") + if (negated) { + builder.append("^") + } + for (a <- characters) { + a match { + case RegexChar(ch) if requiresEscaping(ch) => + // cuDF has stricter escaping requirements for certain characters + // within a character class compared to Java or Python regex + builder.append(s"\\$ch") + case other => + builder.append(other.toRegexString) + } + } + builder.append("]") + builder.toString() + } + + private def requiresEscaping(ch: Char): Boolean = { + // there are likely other cases that we will need to add here but this + // covers everything we have seen so far during fuzzing + ch match { + case '-' => + // cuDF requires '-' to be escaped when used as a character within a character + // to disambiguate from the character range syntax 'a-b' + true + case _ => + false + } + } +} + +class RegexUnsupportedException(message: String, index: Option[Int] = None) + extends SQLException { + override def getMessage: String = { + index match { + case Some(i) => s"$message at index $index" + case _ => message + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index def23bdfc88..bc413b48e58 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -753,10 +753,12 @@ class GpuRLikeMeta( override def tagExprForGpu(): Unit = { expr.right match { case Literal(str: UTF8String, _) => - if (str.toString.contains("\u0000")) { - // see https://github.com/NVIDIA/spark-rapids/issues/3962 - willNotWorkOnGpu("The GPU implementation of RLike does not " + - "support null characters in the pattern") + try { + // verify that we support this regex and can transpile it to cuDF format + new CudfRegexTranspiler().transpile(str.toString) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) } case _ => willNotWorkOnGpu(s"RLike with non-literal pattern is not supported on GPU") @@ -787,7 +789,14 @@ case class GpuRLike(left: Expression, right: Expression) throw new IllegalStateException("Really should not be here, " + "Cannot have an invalid scalar value as right side operand in RLike") } - lhs.getBase.containsRe(pattern) + try { + val cudfRegex = new CudfRegexTranspiler().transpile(pattern) + lhs.getBase.containsRe(cudfRegex) + } catch { + case _: RegexUnsupportedException => + throw new IllegalStateException("Really should not be here, " + + "regular expression should have been verified during tagging") + } } override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala new file mode 100644 index 00000000000..deea53e7900 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import scala.collection.mutable.ListBuffer + +import org.scalatest.FunSuite + +class RegularExpressionParserSuite extends FunSuite { + + test("simple quantifier") { + assert(parse("a{1}") === + RegexSequence(ListBuffer( + RegexRepetition(RegexChar('a'), QuantifierFixedLength(1))))) + } + + test("not a quantifier") { + assert(parse("{1}") === + RegexSequence(ListBuffer( + RegexChar('{'), RegexChar('1'),RegexChar('}')))) + } + + test("nested repetition") { + assert(parse("a*+") === + RegexSequence(ListBuffer( + RegexRepetition( + RegexRepetition(RegexChar('a'), SimpleQuantifier('*')), + SimpleQuantifier('+'))))) + } + + test("choice") { + assert(parse("a|b") === + RegexChoice(RegexSequence(ListBuffer(RegexChar('a'))), + RegexSequence(ListBuffer(RegexChar('b'))))) + } + + test("group") { + assert(parse("(a)(b)") === + RegexSequence(ListBuffer( + RegexGroup(RegexSequence(ListBuffer(RegexChar('a')))), + RegexGroup(RegexSequence(ListBuffer(RegexChar('b'))))))) + } + + test("character class") { + assert(parse("[a-z+A-Z]") === + RegexSequence(ListBuffer( + RegexCharacterClass(negated = false, + ListBuffer( + RegexCharacterRange('a', 'z'), + RegexChar('+'), + RegexCharacterRange('A', 'Z')))))) + } + + test("hex digit") { + assert(parse(raw"\xFF") === + RegexSequence(ListBuffer(RegexHexDigit("FF")))) + } + + test("octal digit") { + val digits = Seq("1", "76", "123", "377") + for (digit <- digits) { + assert(parse(raw"\$digit") === + RegexSequence(ListBuffer(RegexOctalChar(digit)))) + } + + // parsing of the octal digit should terminate after parsing "\1" + assert(parse(raw"\18") === + RegexSequence(ListBuffer(RegexOctalChar("1"), RegexChar('8')))) + + // parsing of the octal digit should terminate after parsing "\47" + assert(parse(raw"\477") === + RegexSequence(ListBuffer(RegexOctalChar("47"), RegexChar('7')))) + } + + test("complex expression") { + val ast = parse( + "^" + // start of line + "[+\\-]?" + // optional + or - at start of string + "(" + + "(" + + "(" + + "([0-9]+)|" + // digits, OR + "([0-9]*\\.[0-9]+)|" + // decimal with optional leading and mandatory trailing, OR + "([0-9]+\\.[0-9]*)" + // decimal with mandatory leading and optional trailing + ")" + + "([eE][+\\-]?[0-9]+)?" + // exponent + "[fFdD]?" + // floating-point designator + ")" + + "|Inf" + // Infinity + "|[nN][aA][nN]" + // NaN + ")" + + "$" // end of line + ) + assert(ast === + RegexSequence(ListBuffer( + RegexChar('^'), + RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexChar('+'), RegexEscaped('-'))), SimpleQuantifier('?')), + RegexGroup(RegexSequence(ListBuffer( + RegexGroup(RegexSequence(ListBuffer( + RegexGroup(RegexSequence(ListBuffer( + RegexGroup(RegexSequence(ListBuffer( + RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))), + RegexChar('|'), + RegexGroup(RegexSequence(ListBuffer( + RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('*')), + RegexEscaped('.'), + RegexRepetition(RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))), + RegexChar('|'), + RegexGroup(RegexSequence(ListBuffer(RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+')), + RegexEscaped('.'), + RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('*')))))))), + RegexRepetition(RegexGroup(RegexSequence(ListBuffer( + RegexCharacterClass(negated = false, + ListBuffer(RegexChar('e'), + RegexChar('E'))), + RegexRepetition(RegexCharacterClass(negated = false, + ListBuffer(RegexChar('+'), RegexEscaped('-'))),SimpleQuantifier('?')), + RegexRepetition(RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))), + SimpleQuantifier('?')), + RegexRepetition(RegexCharacterClass(negated = false, + ListBuffer(RegexChar('f'), RegexChar('F'), + RegexChar('d'), RegexChar('D'))),SimpleQuantifier('?'))))), + RegexChar('|'), RegexChar('I'), RegexChar('n'), RegexChar('f'), RegexChar('|'), + RegexCharacterClass(negated = false, ListBuffer(RegexChar('n'), RegexChar('N'))), + RegexCharacterClass(negated = false, ListBuffer(RegexChar('a'), RegexChar('A'))), + RegexCharacterClass(negated = false, ListBuffer(RegexChar('n'), RegexChar('N'))))) + ), + RegexChar('$')))) + } + + private def parse(pattern: String): RegexAST = { + new RegexParser(pattern).parse() + } + +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala new file mode 100644 index 00000000000..c7f13ee832f --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -0,0 +1,298 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import java.util.regex.Pattern + +import scala.collection.mutable.ListBuffer +import scala.util.{Random, Try} + +import ai.rapids.cudf.{ColumnVector, CudfException} +import org.scalatest.FunSuite + +class RegularExpressionTranspilerSuite extends FunSuite with Arm { + + test("cuDF does not support choice with nothing to repeat") { + val patterns = Seq("b+|^\t") + patterns.foreach(pattern => + assertUnsupported(pattern, "nothing to repeat") + ) + } + + test("cuDF unsupported choice cases") { + val input = Seq("cat", "dog") + val patterns = Seq("c*|d*", "c*|dog", "[cat]{3}|dog") + patterns.foreach(pattern => { + val e = intercept[CudfException] { + gpuContains(pattern, input) + } + assert(e.getMessage.contains("invalid regex pattern: nothing to repeat")) + }) + } + + test("sanity check: choice edge case 2") { + assertThrows[CudfException] { + gpuContains("c+|d+", Seq("cat", "dog")) + } + } + + test("cuDF does not support possessive quantifier") { + val patterns = Seq("a*+", "a|(a?|a*+)") + patterns.foreach(pattern => + assertUnsupported(pattern, "nothing to repeat") + ) + } + + test("cuDF does not support empty sequence") { + val patterns = Seq("", "a|", "()") + patterns.foreach(pattern => + assertUnsupported(pattern, "empty sequence not supported") + ) + } + + test("cuDF does not support quantifier syntax when not quantifying anything") { + // note that we could choose to transpile and escape the '{' and '}' characters + val patterns = Seq("{1,2}", "{1,}", "{1}", "{2,1}") + patterns.foreach(pattern => + assertUnsupported(pattern, "nothing to repeat") + ) + } + + test("cuDF does not support OR at BOL / EOL") { + val patterns = Seq("$|a", "^|a") + patterns.foreach(pattern => { + assertUnsupported(pattern, "nothing to repeat") + }) + } + + test("cuDF does not support null in pattern") { + val patterns = Seq("\u0000", "a\u0000b", "a(\u0000)b", "a[a-b][\u0000]") + patterns.foreach(pattern => + assertUnsupported(pattern, "cuDF does not support null characters in regular expressions")) + } + + test("nothing to repeat") { + val patterns = Seq("$*", "^+") + patterns.foreach(pattern => + assertUnsupported(pattern, "nothing to repeat")) + } + + ignore("known issue - multiline difference between CPU and GPU") { + // see https://github.com/rapidsai/cudf/issues/9620 + val pattern = "2$" + // this matches "2" but not "2\n" on the GPU + assertCpuGpuContainsMatches(Seq(pattern), Seq("2", "2\n", "2\r", "\2\r\n")) + } + + ignore("known issue - dot matches CR on GPU but not on CPU") { + // see https://github.com/rapidsai/cudf/issues/9619 + val pattern = "1." + // '.' matches '\r' on GPU but not on CPU + assertCpuGpuContainsMatches(Seq(pattern), Seq("1\r2", "1\n2", "1\r\n2")) + } + + ignore("known issue - octal digit") { + val pattern = "a\\141|.$" // using hex works fine e.g. "a\\x61|.$" + assertCpuGpuContainsMatches(Seq(pattern), Seq("] b[")) + } + + test("character class with ranges") { + val patterns = Seq("[a-b]", "[a-zA-Z]") + patterns.foreach(parse) + } + + test("character class mixed") { + val patterns = Seq("[a-b]", "[a+b]", "ab[cFef-g][^cat]") + patterns.foreach(parse) + } + + test("transpile character class unescaped range symbol") { + val patterns = Seq("a[-b]", "a[+-]", "a[-+]", "a[-]", "a[^-]") + val expected = Seq(raw"a[\-b]", raw"a[+\-]", raw"a[\-+]", raw"a[\-]", raw"a[^\-]") + val transpiler = new CudfRegexTranspiler() + val transpiled = patterns.map(transpiler.transpile) + assert(transpiled === expected) + } + + test("transpile complex regex 1") { + val VALID_FLOAT_REGEX = + "^" + // start of line + "[+\\-]?" + // optional + or - at start of string + "(" + + "(" + + "(" + + "([0-9]+)|" + // digits, OR + "([0-9]*\\.[0-9]+)|" + // decimal with optional leading and mandatory trailing, OR + "([0-9]+\\.[0-9]*)" + // decimal with mandatory leading and optional trailing + ")" + + "([eE][+\\-]?[0-9]+)?" + // exponent + "[fFdD]?" + // floating-point designator + ")" + + "|Inf" + // Infinity + "|[nN][aA][nN]" + // NaN + ")" + + "$" // end of line + + // input and output should be identical + doTranspileTest(VALID_FLOAT_REGEX, VALID_FLOAT_REGEX) + } + + test("transpile complex regex 2") { + val TIMESTAMP_TRUNCATE_REGEX = "^([0-9]{4}-[0-9]{2}-[0-9]{2} " + + "[0-9]{2}:[0-9]{2}:[0-9]{2})" + + "(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?$" + + // input and output should be identical + doTranspileTest(TIMESTAMP_TRUNCATE_REGEX, TIMESTAMP_TRUNCATE_REGEX) + + } + + test("compare CPU and GPU: character range including unescaped + and -") { + val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]") + val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]") + assertCpuGpuContainsMatches(patterns, inputs) + } + + test("compare CPU and GPU: character range including escaped + and -") { + val patterns = Seq(raw"a[\-\+]", raw"a[\+\-]", raw"a[a-b\-]") + val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]") + assertCpuGpuContainsMatches(patterns, inputs) + } + + test("compare CPU and GPU: hex") { + val patterns = Seq(raw"\x61") + val inputs = Seq("a", "b") + assertCpuGpuContainsMatches(patterns, inputs) + } + + test("compare CPU and GPU: octal") { + val patterns = Seq("\\\\141") + val inputs = Seq("a", "b") + assertCpuGpuContainsMatches(patterns, inputs) + } + + test("compare CPU and GPU: fuzz test with limited chars") { + // testing with this limited set of characters finds issues much + // faster than using the full ASCII set + // CR and LF has been excluded due to known issues + doFuzzTest(Some("|()[]{},.^$*+?abc123x\\ \tB")) + } + + test("compare CPU and GPU: fuzz test printable ASCII chars plus TAB") { + // CR and LF has been excluded due to known issues + doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\t")) + } + + test("compare CPU and GPU: fuzz test ASCII chars") { + // CR and LF has been excluded due to known issues + val chars = (0x00 to 0x7F) + .map(_.toChar) + .filterNot(_ == '\n') + .filterNot(_ == '\r') + doFuzzTest(Some(chars.mkString)) + } + + ignore("compare CPU and GPU: fuzz test all chars") { + // this test cannot be enabled until we support CR and LF + doFuzzTest(None) + } + + private def doFuzzTest(validChars: Option[String]) { + + val r = new EnhancedRandom(new Random(seed = 0L), + options = FuzzerOptions(validChars, maxStringLen = 12)) + + val data = Range(0, 1000).map(_ => r.nextString()) + + // generate patterns that are valid on both CPU and GPU + val patterns = ListBuffer[String]() + while (patterns.length < 5000) { + val pattern = r.nextString() + if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern)).isSuccess) { + patterns += pattern + } + } + + assertCpuGpuContainsMatches(patterns, data) + } + + private def assertCpuGpuContainsMatches(javaPatterns: Seq[String], input: Seq[String]) = { + for (javaPattern <- javaPatterns) { + val cpu = cpuContains(javaPattern, input) + val cudfPattern = new CudfRegexTranspiler().transpile(javaPattern) + val gpu = try { + gpuContains(cudfPattern, input) + } catch { + case e: CudfException => + fail(s"cuDF failed to compile pattern: $cudfPattern", e) + } + for (i <- input.indices) { + if (cpu(i) != gpu(i)) { + fail(s"javaPattern=${toReadableString(javaPattern)}, " + + s"cudfPattern=${toReadableString(cudfPattern)}, " + + s"input='${toReadableString(input(i))}', " + + s"cpu=${cpu(i)}, gpu=${gpu(i)}") + } + } + } + } + + /** cuDF containsRe helper */ + private def gpuContains(cudfPattern: String, input: Seq[String]): Array[Boolean] = { + val result = new Array[Boolean](input.length) + withResource(ColumnVector.fromStrings(input: _*)) { cv => + withResource(cv.containsRe(cudfPattern)) { c => + withResource(c.copyToHost()) { hv => + result.indices.foreach(i => result(i) = hv.getBoolean(i)) + } + } + } + result + } + + private def toReadableString(x: String): String = { + x.map { + case '\r' => "\\r" + case '\n' => "\\n" + case '\t' => "\\t" + case other => other + }.mkString + } + + private def cpuContains(pattern: String, input: Seq[String]): Array[Boolean] = { + val p = Pattern.compile(pattern) + input.map(s => p.matcher(s).find(0)).toArray + } + + private def doTranspileTest(pattern: String, expected: String) { + val transpiled: String = transpile(pattern) + assert(transpiled === expected) + } + + private def transpile(pattern: String): String = { + new CudfRegexTranspiler().transpile(pattern) + } + + private def assertUnsupported(pattern: String, message: String): Unit = { + val e = intercept[RegexUnsupportedException] { + transpile(pattern) + } + assert(e.getMessage.startsWith(message)) + } + + private def parse(pattern: String): RegexAST = new RegexParser(pattern).parse() + +}