From 1ebbee734ada45fdd6e4eba30cc0c7f5619ddda4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 5 Nov 2021 13:24:30 -0600 Subject: [PATCH 01/14] Implement regexp parser to detect when we need to fall back to CPU for RLIKE Signed-off-by: Andy Grove --- docs/compatibility.md | 96 ++- .../src/main/python/string_test.py | 24 +- .../spark/sql/rapids/stringFunctions.scala | 593 +++++++++++++++++- .../rapids/RegularExpressionParserSuite.scala | 54 ++ .../RegularExpressionTranspilerSuite.scala | 292 +++++++++ 5 files changed, 992 insertions(+), 67 deletions(-) create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala diff --git a/docs/compatibility.md b/docs/compatibility.md index 17e642f344c..c6a27d2af73 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -265,6 +265,48 @@ If a null char '\0' is in a string that is being matched by a regular expression the end of the string. This will be fixed in a future release. The issue is [here](https://github.com/NVIDIA/spark-rapids/issues/119). +### RLike + +The GPU implementation of RLike has a number of known issues where behavior is not consistent with Apache Spark and +this expression is disabled by default. It can be enabled by setting `spark.rapids.sql.expression.RLike=true`. + +A summary of known issues is shown below but this is not intended to be a comprehensive list. We recommend that you +do your own testing to verify whether the GPU implementation of `RLike` is suitable for your use case. + +We plan on improving the RLike functionality over time to make it more compatible with Spark so this feature should +be used at your own risk with the expectation that the behavior will change in future releases. 
+ + +#### Null support + +The GPU implementation of RLike supports null characters in the input but does not support null characters in +the regular expression and will fall back to the CPU in this case. + +#### Qualifiers with nothing to repeat + +Spark supports qualifiers in cases where there is nothing to repeat. For example, Spark supports `a*+` and this +will match all inputs. The GPU implementation of RLike does not support this syntax and will throw an exception with +the message `nothing to repeat at position 0`. + +#### Stricter escaping requirements + +The GPU implementation of RLike has stricter requirements around escaping special characters in some cases. + +| Pattern | Input | Spark on CPU | Spark on GPU | +|-----------|--------|--------------|--------------| +| `a[-+]` | `a-` | Match | No Match | +| `a[\-\+]` | `a-` | Match | Match | + +#### Empty groups + +The GPU implementation of RLike does not support empty groups correctly. + +| Pattern | Input | Spark on CPU | Spark on GPU | +|-----------|--------|--------------|--------------| +| `z()?` | `a` | No Match | Match | +| `z()*` | `a` | No Match | Match | + + ## Timestamps Spark stores timestamps internally relative to the JVM time zone. Converting an arbitrary timestamp @@ -569,60 +611,6 @@ distribution. Because the results are not bit-for-bit identical with the Apache `approximate_percentile`, this feature is disabled by default and can be enabled by setting `spark.rapids.sql.expression.ApproximatePercentile=true`. -## RLike - -The GPU implementation of RLike has a number of known issues where behavior is not consistent with Apache Spark and -this expression is disabled by default. It can be enabled setting `spark.rapids.sql.expression.RLike=true`. - -A summary of known issues is shown below but this is not intended to be a comprehensive list. We recommend that you -do your own testing to verify whether the GPU implementation of `RLike` is suitable for your use case. 
- -We plan on improving the RLike functionality over time to make it more compatible with Spark so this feature should -be used at your own risk with the expectation that the behavior will change in future releases. - -### Multi-line handling - -The GPU implementation of RLike supports `^` and `$` to represent the start and end of lines within a string but -Spark uses `^` and `$` to refer to the start and end of the entire string (equivalent to `\A` and `\Z`). - -| Pattern | Input | Spark on CPU | Spark on GPU | -|---------|--------|--------------|--------------| -| `^A` | `A\nB` | Match | Match | -| `A$` | `A\nB` | No Match | Match | -| `^B` | `A\nB` | No Match | Match | -| `B$` | `A\nB` | Match | Match | - -As a workaround, `\A` and `\Z` can be used instead of `^` and `$`. - -### Null support - -The GPU implementation of RLike supports null characters in the input but does not support null characters in -the regular expression and will fall back to the CPU in this case. - -### Qualifiers with nothing to repeat - -Spark supports qualifiers in cases where there is nothing to repeat. For example, Spark supports `a*+` and this -will match all inputs. The GPU implementation of RLike does not support this syntax and will throw an exception with -the message `nothing to repeat at position 0`. - -### Stricter escaping requirements - -The GPU implementation of RLike has stricter requirements around escaping special characters in some cases. - -| Pattern | Input | Spark on CPU | Spark on GPU | -|-----------|--------|--------------|--------------| -| `a[-+]` | `a-` | Match | No Match | -| `a[\-\+]` | `a-` | Match | Match | - -### Empty groups - -The GPU implementation of RLike does not support empty groups correctly. 
- -| Pattern | Input | Spark on CPU | Spark on GPU | -|-----------|--------|--------------|--------------| -| `z()?` | `a` | No Match | Match | -| `z()*` | `a` | No Match | Match | - ## Conditionals and operations with side effects (ANSI mode) In Apache Spark condition operations like `if`, `coalesce`, and `case/when` lazily evaluate diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 792d870580b..cfd0010d2ef 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -492,7 +492,7 @@ def test_rlike_embedded_null(): conf={'spark.rapids.sql.expression.RLike': 'true'}) @allow_non_gpu('ProjectExec', 'RLike') -def test_rlike_null_pattern(): +def test_rlike_fallback_null_pattern(): gen = mk_str_gen('[abcd]{1,3}') assert_gpu_fallback_collect( lambda spark: unary_op_df(spark, gen).selectExpr( @@ -500,6 +500,15 @@ def test_rlike_null_pattern(): 'RLike', conf={'spark.rapids.sql.expression.RLike': 'true'}) +@allow_non_gpu('ProjectExec', 'RLike') +def test_rlike_fallback_empty_group(): + gen = mk_str_gen('[abcd]{1,3}') + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'a rlike "a()?"'), + 'RLike', + conf={'spark.rapids.sql.expression.RLike': 'true'}) + def test_rlike_escape(): gen = mk_str_gen('[ab]{0,2}[\\-\\+]{0,2}') assert_gpu_and_cpu_are_equal_collect( @@ -507,7 +516,6 @@ def test_rlike_escape(): 'a rlike "a[\\\\-]"'), conf={'spark.rapids.sql.expression.RLike': 'true'}) -@pytest.mark.xfail(reason='cuDF supports multiline by default but Spark does not - https://github.com/rapidsai/cudf/issues/9439') def test_rlike_multi_line(): gen = mk_str_gen('[abc]\n[def]') assert_gpu_and_cpu_are_equal_collect( @@ -518,18 +526,20 @@ def test_rlike_multi_line(): 'a rlike "e$"'), conf={'spark.rapids.sql.expression.RLike': 'true'}) -@pytest.mark.xfail(reason='cuDF has stricter requirements around escaping - 
https://github.com/rapidsai/cudf/issues/9434') def test_rlike_missing_escape(): gen = mk_str_gen('a[\\-\\+]') assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, gen).selectExpr( - 'a rlike "a[-]"'), + 'a rlike "a[-]"', + 'a rlike "a[+-]"', + 'a rlike "a[a-b-]"'), conf={'spark.rapids.sql.expression.RLike': 'true'}) -@pytest.mark.xfail(reason='cuDF does not support qualifier with nothing to repeat - https://github.com/rapidsai/cudf/issues/9434') -def test_rlike_nothing_to_repeat(): +@allow_non_gpu('ProjectExec', 'RLike') +def test_rlike_fallback_possessive_quantifier(): gen = mk_str_gen('(\u20ac|\\w){0,3}a[|b*.$\r\n]{0,2}c\\w{0,3}') - assert_gpu_and_cpu_are_equal_collect( + assert_gpu_fallback_collect( lambda spark: unary_op_df(spark, gen).selectExpr( 'a rlike "a*+"'), + 'RLike', conf={'spark.rapids.sql.expression.RLike': 'true'}) \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index def23bdfc88..456544e41d1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -16,7 +16,9 @@ package org.apache.spark.sql.rapids -import scala.collection.mutable.ArrayBuffer +import java.sql.SQLException + +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import ai.rapids.cudf.{ColumnVector, ColumnView, DType, PadSide, Scalar, Table} import com.nvidia.spark.rapids._ @@ -753,10 +755,12 @@ class GpuRLikeMeta( override def tagExprForGpu(): Unit = { expr.right match { case Literal(str: UTF8String, _) => - if (str.toString.contains("\u0000")) { - // see https://github.com/NVIDIA/spark-rapids/issues/3962 - willNotWorkOnGpu("The GPU implementation of RLike does not " + - "support null characters in the pattern") + try { + // verify that we support this regex and can transpile it to cuDF format + new 
CudfRegexTranspiler().transpile(str.toString) + } catch { + case e: RegexUnsupportedException => + willNotWorkOnGpu(e.getMessage) } case _ => willNotWorkOnGpu(s"RLike with non-literal pattern is not supported on GPU") @@ -767,6 +771,576 @@ class GpuRLikeMeta( GpuRLike(lhs, rhs) } +/** + * Regular expression parser. + * + * Suggested reading before making changes to this code: + * + * - https://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/ + * - https://matt.might.net/articles/parsing-regex-with-recursive-descent/ + */ +class RegexParser(pattern: String) { + + /** index of current position within the string being parsed */ + private var i = 0 + + def parse(): RegexAST = { + val ast = parseInternal() + if (!eof()) { + throw new RegexUnsupportedException("failed to parse full regex") + } + ast + } + + private def parseInternal(): RegexAST = { + val term = parseTerm(() => peek().contains('|')) + if (!eof() && peek().contains('|')) { + consumeExpected('|') + RegexChoice(term, parseInternal()) + } else { + term + } + } + + private def parseTerm(until: () => Boolean): RegexAST = { + val sequence = RegexSequence(new ListBuffer()) + while (!eof() && !until()) { + parseFactor() match { + case RegexSequence(parts) => + sequence.parts ++= parts + case other => + sequence.parts += other + } + } + sequence + } + + private def isValidQuantifierAhead(): Boolean = { + if (peek().contains('{')) { + val bookmark = i + consumeExpected('{') + val q = parseQuantifierOrLiteralBrace() + i = bookmark + q match { + case _: QuantifierFixedLength | _: QuantifierVariableLength => true + case _ => false + } + } else { + false + } + } + + private def parseFactor(): RegexAST = { + // TODO rewrite this + var base = parseBase() + while (!eof() && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?') + || isValidQuantifierAhead())) { + + if (peek().contains('{')) { + consumeExpected('{') + base = RegexRepetition(base, 
parseQuantifierOrLiteralBrace().asInstanceOf[RegexQuantifier]) + } else { + base = RegexRepetition(base, SimpleQuantifier(consume())) + } + } + base + } + + private def parseBase(): RegexAST = { + consume() match { + case '(' => + parseGroup() + case '[' => + parseCharacterClass() + case '\\' => + parseEscapedCharacter() + case '\u0000' => + throw new RegexUnsupportedException( + "cuDF does not support null characters in regular expressions", Some(i)) + case other => + RegexChar(other) + } + } + + private def parseGroup(): RegexAST = { + val term = parseTerm(() => peek().contains(')')) + consumeExpected(')') + RegexGroup(term) + } + + /** + * Parse a character class as defined in the cuDF documentation at + * https://docs.rapids.ai/api/libcudf/stable/md_regex.html + */ + private def parseCharacterClass(): RegexCharacterClass = { + val start = i + val characterClass = RegexCharacterClass(negated = false, characters = ListBuffer()) + // loop until the end of the character class or EOF + var characterClassComplete = false + while (!eof() && !characterClassComplete) { + val ch = consume() + ch match { + case '[' => + // treat as a literal character and add to the character class + characterClass.append(ch) + case ']' => + characterClassComplete = true + case '^' if i == start + 1 => + // Negates the character class, causing it to match a single character not listed in + // the character class. 
Only valid immediately after the opening '[' + characterClass.negated = true + case '\n' | '\r' | '\t' | '\b' | '\f' => + // TODO add \a here as well + // Add this special character to the character class + characterClass.append(ch) + case '\\' => + peek() match { + case None => + throw new RegexUnsupportedException( + s"unexpected EOF while parsing escaped character", Some(i)) + case Some(ch) => + ch match { + case '\\' | '^' | '-' | ']' | '+' => + // escaped metacharacter within character class + characterClass.appendEscaped(consumeExpected(ch)) + case other => + // any other escape (for example \d inside a class) is not handled yet, so + // report it as unsupported rather than failing with a scala.MatchError + throw new RegexUnsupportedException( + s"unsupported escaped character '$other' in character class", Some(i)) + } + } + case '\u0000' => + throw new RegexUnsupportedException( + "cuDF does not support null characters in regular expressions", Some(i)) + case _ => + // check for range + val start = ch + peek() match { + case Some('-') => + consumeExpected('-') + // TODO look for other non valid chars like escape + peek() match { + case Some(']') => + characterClass.append(ch) + characterClass.append('-') + case Some(end) => + skip() + characterClass.appendRange(start, end) + case None => + // pattern ended right after the '-' of a range such as "[a-" + throw new RegexUnsupportedException( + s"unexpected EOF while parsing character range", Some(i)) + } + case _ => + // treat as supported literal character + characterClass.append(ch) + } + } + } + if (!characterClassComplete) { + throw new RegexUnsupportedException( + s"unexpected EOF while parsing character class", Some(i)) + } + characterClass + } + + + /** + * Parse a quantifier in one of the following formats: + * + * {n} + * {n,} + * {n,m} (only valid if m >= n) + */ + private def parseQuantifierOrLiteralBrace(): RegexAST = { + //TODO refactor to avoid code duplication + // assumes that '{' has already been consumed + val start = i + consumeInt match { + case None => + // this was not a quantifier, just a literal '{' + i = start + 1 + RegexChar('{') + case Some(minLength) => + peek() match { + case Some(',') => + consumeExpected(',') + val max = consumeInt() + if (peek().contains('}')) { + // end of quantifier + consumeExpected('}') + max match { + case None => + QuantifierVariableLength(minLength, None) + case Some(m) => + if (m >= minLength) { + 
QuantifierVariableLength(minLength, max) + } else { + // this was not a quantifier, just a literal '{' + i = start + 1 + RegexChar('{') + } + } + } else { + // this was not a quantifier, just a literal '{' + i = start + 1 + RegexChar('{') + } + case Some('}') => + // end of quantifier + consumeExpected('}') + QuantifierFixedLength(minLength) + case _ => + // this was not a quantifier, just a literal '{' + i = start + 1 + RegexChar('{') + } + } + } + + private def parseEscapedCharacter(): RegexAST = { + peek() match { + case None => + throw new RegexUnsupportedException("escape at end of string", Some(i)) + case Some(ch) => + consumeExpected(ch) + ch match { + case 'A' | 'Z' => + // anchors BOL / EOL + RegexEscaped(ch) + case 's' | 'S' | 'd' | 'D' | 'w' | 'W' => + // meta sequences + RegexEscaped(ch) + case 'B' | 'b' => + // word boundaries + RegexEscaped(ch) + case '[' | '\\' | '^' | '$' | '.' | '⎮' | '?' | '*' | '+' | '(' | ')' | '{' | '}' => + // escaped metacharacter + RegexEscaped(ch) + case 'x' => + parseHexDigit + case _ if ch.isDigit => + parseOctalDigit + case other => + //TODO handle this in transpiler not in parser + throw new RegexUnsupportedException( + s"invalid or unsupported escape character '$other'", Some(i - 1)) + } + } + } + + private def parseHexDigit: RegexAST = { + + def isHexDigit(ch: Char): Boolean = ch.isDigit || + (ch >= 'a' && ch <= 'f') || + (ch >= 'A' && ch <= 'F') + + if (i + 1 < pattern.length + && isHexDigit(pattern.charAt(i)) + && isHexDigit(pattern.charAt(i+1))) { + val hex = pattern.substring(i, i + 2) + i += 2 + RegexHexChar(hex) + } else { + throw new RegexUnsupportedException( + "Invalid hex digit", Some(i)) + } + } + + private def parseOctalDigit = { + + if (i + 2 < pattern.length + && pattern.charAt(i).isDigit + && pattern.charAt(i+1).isDigit + && pattern.charAt(i+2).isDigit) { + // capture all three digits that were validated above and are skipped below + val octal = pattern.substring(i, i + 3) + i += 3 + RegexOctalChar(octal) + } else { + throw new RegexUnsupportedException( + "Invalid octal digit", 
Some(i)) + } + } + + /** Determine if we are at the end of the input */ + private def eof(): Boolean = i == pattern.length + + /** Advance the index by one */ + private def skip(): Unit = { + if (eof()) { + throw new RegexUnsupportedException("Unexpected EOF", Some(i)) + } + i += 1 + } + + /** Get the next character and advance the index by one */ + private def consume(): Char = { + if (eof()) { + throw new RegexUnsupportedException("Unexpected EOF", Some(i)) + } else { + i += 1 + pattern.charAt(i - 1) + } + } + + /** Consume the next character if it is the one we expect */ + private def consumeExpected(expected: Char): Char = { + val consumed = consume() + if (consumed != expected) { + throw new RegexUnsupportedException( + s"Expected '$expected' but found '$consumed'", Some(i-1)) + } + consumed + } + + /** Peek at the next character without consuming it */ + private def peek(): Option[Char] = { + if (eof()) { + None + } else { + Some(pattern.charAt(i)) + } + } + + private def consumeInt(): Option[Int] = { + val start = i + while (!eof() && peek().exists(_.isDigit)) { + skip() + } + if (start == i) { + None + } else { + Some(pattern.substring(start, i).toInt) + } + } + +} + +/** + * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception + * if this is not possible. 
+ */ +class CudfRegexTranspiler { + + def transpile(pattern: String): String = { + // parse the source regular expression + val regex = new RegexParser(pattern).parse() + // validate that the regex is supported by cuDF + validate(regex) + // write out to regex string, performing minor transformations + regex.toRegexString + } + + private def validate(regex: RegexAST): Unit = { + regex match { + case RegexGroup(RegexSequence(parts)) => + if (parts.isEmpty) { + throw new RegexUnsupportedException("cuDF does not support empty groups") + } + val s = parts.map(_.toRegexString).mkString + if (s.startsWith("?") && !s.startsWith("?:")) { + throw new RegexUnsupportedException("nothing to repeat") + } + case RegexRepetition(RegexEscaped(_), _) => + throw new RegexUnsupportedException("nothing to repeat") + case RegexRepetition(RegexChar(a), _) if "$^.".contains(a) => + throw new RegexUnsupportedException("nothing to repeat") + case RegexRepetition(RegexRepetition(_, _), _) => + // cuDF generally does not support nested repetitions such as possessive quantifiers ("a*+") + throw new RegexUnsupportedException("nothing to repeat") + case RegexChoice(l, r) => + (l, r) match { + // check for empty left-hand side caused by ^ or $ or a repetition + case (RegexSequence(a), _) => + a.lastOption match { + case None => + throw new RegexUnsupportedException("nothing to repeat") + case Some(RegexChar(ch)) if ch == '$' || ch == '^' => + throw new RegexUnsupportedException("nothing to repeat") + case Some(RegexRepetition(_, _)) => + throw new RegexUnsupportedException("nothing to repeat") + case _ => + } + // NOTE(review): RegexParser always produces a RegexSequence for the left-hand side + // of a choice, so the case above always matches and the right-hand side check + // below is never reached for parsed input - confirm whether that is intended + // check for empty right-hand side caused by ^ or $ + case (_, RegexSequence(b)) => + b.headOption match { + case None => + throw new RegexUnsupportedException("nothing to repeat") + case Some(RegexChar(ch)) if ch == '$' || ch == '^' => + throw new RegexUnsupportedException("nothing to repeat") + case _ => + } + case (RegexRepetition(_, _), _) => + throw new 
RegexUnsupportedException("nothing to repeat") + case _ => + } + + case RegexSequence(parts) if parts.isEmpty => + // an empty sequence (for example the empty pattern "") would make the head/last + // lookups below throw NoSuchElementException, so report it as unsupported instead + throw new RegexUnsupportedException("empty sequence not supported") + + //TODO do not use toRegexString here + case RegexSequence(parts) => + if (parts.head.toRegexString == "|" || parts.last.toRegexString == "|") { + throw new RegexUnsupportedException("nothing to repeat") + } + //TODO this is too hacky + if (parts.head.toRegexString.startsWith("{")) { + // cuDF would treat this as a quantifier even though in this + // context (being at the start of a sequence) it is not quantifying anything + throw new RegexUnsupportedException("nothing to repeat") + } + parts.foreach { + case RegexEscaped(ch) if ch == 'b' || ch == 'B' => + // this needs further analysis to determine why words boundaries behave + // differently between Java and cuDF + throw new RegexUnsupportedException("word boundaries are not supported") + case _ => + } + case RegexCharacterClass(negated, characters) => + characters.foreach { + case RegexChar(ch) if ch == '[' || ch == ']' => + // this can have very different semantics between Java and cuDF + throw new RegexUnsupportedException("nested character classes are not supported") + case _ => + } + + case _ => + } + regex.children().foreach(validate) + } +} + +sealed trait RegexAST { + def children(): Seq[RegexAST] + def toRegexString: String +} + +sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST { + override def children(): Seq[RegexAST] = parts + override def toRegexString: String = parts.map(_.toRegexString).mkString +} + +sealed case class RegexGroup(term: RegexAST) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(term) + override def toRegexString: String = s"(${term.toRegexString})" +} + +sealed case class RegexChoice(a: RegexAST, b: RegexAST) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(a, b) + override def toRegexString: String = s"${a.toRegexString}|${b.toRegexString}" +} + +sealed case class RegexRepetition(a: RegexAST, quantifier: RegexQuantifier) extends RegexAST { + override def 
children(): Seq[RegexAST] = Seq(a) + override def toRegexString: String = s"${a.toRegexString}${quantifier.toRegexString}" +} + +sealed trait RegexQuantifier extends RegexAST + +sealed case class SimpleQuantifier(ch: Char) extends RegexQuantifier { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = ch.toString +} + +sealed case class QuantifierFixedLength(length: Int) + extends RegexQuantifier { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = { + s"{$length}" + } +} + +sealed case class QuantifierVariableLength(minLength: Int, maxLength: Option[Int]) + extends RegexQuantifier{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = { + maxLength match { + case Some(max) => + s"{$minLength,$max}" + case _ => + s"{$minLength,}" + } + } +} + +sealed trait RegexCharacterClassComponent extends RegexAST + +sealed case class RegexHexChar(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\x$a" +} + +sealed case class RegexOctalChar(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\$a" +} + +sealed case class RegexChar(a: Char) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"$a" +} + +sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\$a" +} + +sealed case class RegexCharacterRange(start: Char, end: Char) + extends RegexCharacterClassComponent{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"$start-$end" +} + +sealed case class RegexCharacterClass( + var negated: Boolean, + var characters: 
ListBuffer[RegexCharacterClassComponent]) + extends RegexAST { + + override def children(): Seq[RegexAST] = characters + def append(ch: Char): Unit = { + characters += RegexChar(ch) + } + + def appendEscaped(ch: Char): Unit = { + characters += RegexEscaped(ch) + } + + def appendRange(start: Char, end: Char): Unit = { + characters += RegexCharacterRange(start, end) + } + + override def toRegexString: String = { + val builder = new StringBuilder("[") + if (negated) { + builder.append("^") + } + for (a <- characters) { + a match { + case RegexChar(ch) if requiresEscaping(ch) => + // cuDF has stricter escaping requirements for certain characters + // within a character class compared to Java or Python regex + builder.append(s"\\$ch") + case other => + builder.append(other.toRegexString) + } + } + builder.append("]") + builder.toString() + } + + private def requiresEscaping(ch: Char): Boolean = { + // there are likely other cases that we will need to add here but this + // covers everything we have seen so far during fuzzing + ch match { + case '-' => + // cuDF requires '-' to be escaped when used as a character within a character class + // to disambiguate from the character range syntax 'a-b' + true + case _ => + false + } + } +} + +/** + * Thrown when a regular expression cannot be parsed or is not supported by the transpiler. + * + * @param message description of the problem + * @param index optional position within the pattern where the problem was detected; when + * present it is appended to the message returned by getMessage + */ +class RegexUnsupportedException(message: String, index: Option[Int] = None) + extends SQLException { + override def getMessage: String = { + index match { + case Some(i) => s"$message at index $i" + case _ => message + } + } +} + case class GpuRLike(left: Expression, right: Expression) extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant { @@ -787,7 +1361,14 @@ case class GpuRLike(left: Expression, right: Expression) throw new IllegalStateException("Really should not be here, " + "Cannot have an invalid scalar value as right side operand in RLike") } - lhs.getBase.containsRe(pattern) + try { + val cudfRegex = new CudfRegexTranspiler().transpile(pattern) + lhs.getBase.containsRe(cudfRegex) + } catch { + case _: 
RegexUnsupportedException => + throw new IllegalStateException("Really should not be here, " + + "regular expression should have been verified during tagging") + } } override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala new file mode 100644 index 00000000000..a2a93e2eff5 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.nvidia.spark.rapids + +import scala.collection.mutable.ListBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.sql.rapids.{QuantifierFixedLength, RegexAST, RegexChar, RegexParser, RegexRepetition, RegexSequence, SimpleQuantifier} + +class RegularExpressionParserSuite extends FunSuite { + + test("simple quantifier") { + assert(parse("a{1}") === + RegexSequence(ListBuffer( + RegexRepetition(RegexChar('a'), QuantifierFixedLength(1))))) + } + + test("not a quantifier") { + assert(parse("{1}") === + RegexSequence(ListBuffer( + RegexChar('{'), RegexChar('1'),RegexChar('}')))) + } + + test("nested repetition") { + assert(parse("a*+") === + RegexSequence(ListBuffer( + RegexRepetition( + RegexRepetition(RegexChar('a'), SimpleQuantifier('*')), + SimpleQuantifier('+'))))) + } + + test("adhoc") { + println(parse("ab*$}\\B? ")) + } + + private def parse(pattern: String): RegexAST = { + new RegexParser(pattern).parse() + } + +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala new file mode 100644 index 00000000000..1d5723cccd5 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.nvidia.spark.rapids
+
+import java.util.regex.Pattern
+
+import scala.collection.mutable.ListBuffer
+import scala.util.{Random, Try}
+
+import ai.rapids.cudf.{ColumnVector, CudfException}
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.rapids.{CudfRegexTranspiler, RegexAST, RegexParser, RegexUnsupportedException}
+
+/**
+ * Tests for transpiling Java regular expressions to cuDF regular expressions.
+ *
+ * The suite checks that patterns are rejected with a `RegexUnsupportedException`
+ * where expected, and compares `java.util.regex` results with cuDF `containsRe`
+ * results for the patterns that do transpile.
+ */
+class RegularExpressionTranspilerSuite extends FunSuite with Arm {
+
+  test("cuDF does not support choice with nothing to repeat") {
+    val patterns = Seq("b+|^\t")
+    patterns.foreach(pattern =>
+      assertUnsupported(pattern, "nothing to repeat")
+    )
+  }
+
+  test("cuDF unsupported choice cases") {
+    // these patterns are passed to cuDF without transpiling
+    val input = Seq("cat", "dog")
+    val patterns = Seq("c*|d*", "c*|dog", "[cat]{3}|dog")
+    patterns.foreach(pattern => {
+      val e = intercept[CudfException] {
+        gpuContains(pattern, input)
+      }
+      assert(e.getMessage.contains("invalid regex pattern: nothing to repeat"))
+    })
+  }
+
+  test("sanity check: choice edge case 2") {
+    assertThrows[CudfException] {
+      gpuContains("c+|d+", Seq("cat", "dog"))
+    }
+  }
+
+  test("cuDF does not support possessive quantifier") {
+    val patterns = Seq("a*+", "a|(a?|a*+)")
+    patterns.foreach(pattern =>
+      assertUnsupported(pattern, "nothing to repeat")
+    )
+  }
+
+  test("cuDF does not support empty groups") {
+    assertUnsupported("a()?", "cuDF does not support empty groups")
+  }
+
+  test("cuDF does not support quantifier syntax when not quantifying anything") {
+    // note that we could choose to transpile and escape the '{' and '}' characters
+    val patterns = Seq("{1,2}", "{1,}", "{1}", "{2,1}")
+    patterns.foreach(pattern =>
+      assertUnsupported(pattern, "nothing to repeat")
+    )
+  }
+
+  test("cuDF does not support OR at BOL / EOL") {
+    val patterns = Seq("$|a", "^|a")
+    // "a|$" ?
+    // "a|^" ?
+    patterns.foreach(pattern => {
+      println(pattern)
+      assertUnsupported(pattern, "nothing to repeat")
+    })
+  }
+
+  test("cuDF does not support null in pattern") {
+    val patterns = Seq("\u0000", "a\u0000b", "a(\u0000)b", "a[a-b][\u0000]")
+    patterns.foreach(pattern =>
+      assertUnsupported(pattern, "cuDF does not support null characters in regular expressions"))
+  }
+
+  test("nothing to repeat") {
+    val patterns = Seq("$*", "^+", ".*")
+    patterns.foreach(pattern =>
+      assertUnsupported(pattern, "nothing to repeat"))
+  }
+
+  ignore("known issue - multiline difference between CPU and GPU") {
+    // see https://github.com/rapidsai/cudf/issues/9620
+    val pattern = "2$"
+    // this matches "2" but not "2\n" on the GPU
+    // (the last input was previously "\2\r\n", which is an octal escape for
+    // U+0002 rather than the digit '2')
+    assertCpuGpuContainsMatches(Seq(pattern), Seq("2", "2\n", "2\r", "2\r\n"))
+  }
+
+  ignore("known issue - dot matches CR on GPU but not on CPU") {
+    // see https://github.com/rapidsai/cudf/issues/9619
+    val pattern = "1."
+    // '.' matches '\r' on GPU but not on CPU
+    assertCpuGpuContainsMatches(Seq(pattern), Seq("1\r2", "1\n2", "1\r\n2"))
+
+  }
+
+  test("character class with ranges") {
+    val patterns = Seq("[a-b]", "[a-zA-Z]")
+    patterns.foreach(parse)
+  }
+
+  test("character class mixed") {
+    val patterns = Seq("[a-b]", "[a+b]", "ab[cFef-g][^cat]")
+    patterns.foreach(parse)
+  }
+
+  test("transpile character class unescaped range symbol") {
+    val patterns = Seq("a[-b]", "a[+-]", "a[-+]", "a[-]", "a[^-]")
+    val expected = Seq(raw"a[\-b]", raw"a[+\-]", raw"a[\-+]", raw"a[\-]", raw"a[^\-]")
+    val transpiler = new CudfRegexTranspiler()
+    val transpiled = patterns.map(transpiler.transpile)
+    assert(transpiled === expected)
+  }
+
+  test("transpile complex regex 1") {
+    val VALID_FLOAT_REGEX =
+      "^" + // start of line
+        "[+\\-]?" + // optional + or - at start of string
+        "(" +
+        "(" +
+        "(" +
+        "([0-9]+)|" + // digits, OR
+        "([0-9]*\\.[0-9]+)|" + // decimal with optional leading and mandatory trailing, OR
+        "([0-9]+\\.[0-9]*)" + // decimal with mandatory leading and optional trailing
+        ")" +
+        "([eE][+\\-]?[0-9]+)?" + // exponent
+        "[fFdD]?" + // floating-point designator
+        ")" +
+        "|Inf" + // Infinity
+        "|[nN][aA][nN]" + // NaN
+        ")" +
+        "$" // end of line
+
+    // input and output should be identical
+    doTranspileTest(VALID_FLOAT_REGEX, VALID_FLOAT_REGEX)
+  }
+
+  test("transpile complex regex 2") {
+    val TIMESTAMP_TRUNCATE_REGEX = "^([0-9]{4}-[0-9]{2}-[0-9]{2} " +
+      "[0-9]{2}:[0-9]{2}:[0-9]{2})" +
+      "(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?$"
+
+    // input and output should be identical
+    doTranspileTest(TIMESTAMP_TRUNCATE_REGEX, TIMESTAMP_TRUNCATE_REGEX)
+
+  }
+
+  test("compare CPU and GPU: character range including unescaped + and -") {
+    val patterns = Seq("a[-]+", "a[a-b-]+", "a[-a-b]", "a[-+]", "a[+-]")
+    val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]")
+    assertCpuGpuContainsMatches(patterns, inputs)
+  }
+
+  test("compare CPU and GPU: character range including escaped + and -") {
+    val patterns = Seq(raw"a[\-\+]", raw"a[\+\-]", raw"a[a-b\-]")
+    val inputs = Seq("a+", "a-", "a", "a-+", "a[a-b-]")
+    assertCpuGpuContainsMatches(patterns, inputs)
+  }
+
+  test("compare CPU and GPU: hex") {
+    val patterns = Seq(raw"\x61")
+    val inputs = Seq("a", "b")
+    assertCpuGpuContainsMatches(patterns, inputs)
+  }
+
+  test("compare CPU and GPU: octal") {
+    // NOTE(review): this literal is the regex `\\141`, i.e. an escaped backslash
+    // followed by the digits 141, so it does not appear to exercise an octal
+    // escape (Java spells that `\0141`) - confirm the intended pattern
+    val patterns = Seq("\\\\141")
+    val inputs = Seq("a", "b")
+    assertCpuGpuContainsMatches(patterns, inputs)
+  }
+
+  test("compare CPU and GPU: fuzz test with limited chars") {
+    // testing with this limited set of characters finds issues much
+    // faster than using the full ASCII set
+    // CR and LF has been excluded due to known issues
+    doFuzzTest(Some("|()[]{},.^$*+?abc123x\\ \tB"))
+  }
+
+  test("compare CPU and GPU: fuzz test printable ASCII chars plus TAB") {
+    // CR and LF has been excluded due to known issues
+    // mkString is required here: `IndexedSeq[Char] + "\t"` would go through
+    // any2stringadd and produce the text "Vector( , !, ...)" instead
+    doFuzzTest(Some((0x20 to 0x7F).map(_.toChar).mkString + "\t"))
+  }
+
+  test("compare CPU and GPU: fuzz test ASCII chars") {
+    // CR and LF has been excluded due to known issues
+    val chars = (0x00 to 0x7F)
+      .map(_.toChar)
+      .filterNot(_ == '\n')
+      .filterNot(_ == '\r')
+    doFuzzTest(Some(chars.mkString))
+  }
+
+  ignore("compare CPU and GPU: fuzz test all chars") {
+    // this test cannot be enabled until we support CR and LF
+    doFuzzTest(None)
+  }
+
+  /**
+   * Generate random input strings and random patterns with a fixed seed, keep the
+   * patterns that both `Pattern.compile` and the transpiler accept, and compare
+   * CPU and GPU results for every pattern/input combination.
+   *
+   * @param validChars characters passed to `FuzzerOptions` for string generation
+   */
+  private def doFuzzTest(validChars: Option[String]): Unit = {
+
+    val r = new EnhancedRandom(new Random(seed = 0L),
+      options = FuzzerOptions(validChars, maxStringLen = 12))
+
+    val data = Range(0, 1000).map(_ => r.nextString())
+
+    // generate patterns that are valid on both CPU and GPU
+    val patterns = ListBuffer[String]()
+    while (patterns.length < 5000) {
+      val pattern = r.nextString()
+      if (Try(Pattern.compile(pattern)).isSuccess && Try(transpile(pattern)).isSuccess) {
+        patterns += pattern
+      }
+    }
+
+    assertCpuGpuContainsMatches(patterns, data)
+  }
+
+  /**
+   * Transpile each Java pattern and fail on the first input where the CPU
+   * result and the GPU result differ.
+   */
+  private def assertCpuGpuContainsMatches(javaPatterns: Seq[String], input: Seq[String]): Unit = {
+    for (javaPattern <- javaPatterns) {
+      val cudfPattern = new CudfRegexTranspiler().transpile(javaPattern)
+      val cpu = cpuContains(javaPattern, input)
+      val gpu = gpuContains(cudfPattern, input)
+      for (i <- input.indices) {
+        if (cpu(i) != gpu(i)) {
+          fail(s"javaPattern=${toReadableString(javaPattern)}, " +
+            s"cudfPattern=${toReadableString(cudfPattern)}, " +
+            s"input='${toReadableString(input(i))}', " +
+            s"cpu=${cpu(i)}, gpu=${gpu(i)}")
+        }
+      }
+    }
+  }
+
+  /** cuDF containsRe helper */
+  private def gpuContains(cudfPattern: String, input: Seq[String]): Array[Boolean] = {
+    val result = new Array[Boolean](input.length)
+    withResource(ColumnVector.fromStrings(input: _*)) { cv =>
+      withResource(cv.containsRe(cudfPattern)) { c =>
+        withResource(c.copyToHost()) { hv =>
+          result.indices.foreach(i => result(i) = hv.getBoolean(i))
+        }
+      }
+    }
+    result
+  }
+
+  /** Replace CR, LF and TAB with their escaped spelling for failure messages */
+  private def toReadableString(x: String): String = {
+    x.map {
+      case '\r' => "\\r"
+      case '\n' => "\\n"
+      case '\t' => "\\t"
+      case other => other
+    }.mkString
+  }
+
+  /** java.util.regex helper: true for each input that contains a match */
+  private def cpuContains(pattern: String, input: Seq[String]): Array[Boolean] = {
+    val p = Pattern.compile(pattern)
+    input.map(s => p.matcher(s).find(0)).toArray
+  }
+
+  private def doTranspileTest(pattern: String, expected: String): Unit = {
+    val transpiled: String = transpile(pattern)
+    assert(transpiled === expected)
+  }
+
+  private def transpile(pattern: String): String = {
+    new CudfRegexTranspiler().transpile(pattern)
+  }
+
+  /** Assert that transpiling fails with a message starting with `message` */
+  private def assertUnsupported(pattern: String, message: String): Unit = {
+    val e = intercept[RegexUnsupportedException] {
+      transpile(pattern)
+    }
+    assert(e.getMessage.startsWith(message))
+  }
+
+  private def parse(pattern: String): RegexAST = new RegexParser(pattern).parse()
+
+}
From c5fcb28e4d3192748997c26d4c9bc69a4bf0a1b9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 5 Nov 2021 13:52:54 -0600 Subject: [PATCH 02/14] update compatibility docs --- docs/compatibility.md | 54 ++++++++++++------------------------------- 1 file changed, 15 insertions(+), 39 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index c6a27d2af73..95591c185d3 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -257,55 +257,31 @@ The plugin supports reading `uncompressed`, `snappy` and `gzip` Parquet files an fall back to the CPU when reading an unsupported compression format, and will error out in that case. -## Regular Expressions -The RAPIDS Accelerator for Apache Spark currently supports string literal matches, not wildcard -matches. +## LIKE -If a null char '\0' is in a string that is being matched by a regular expression, `LIKE` sees it as +If a null char '\0' is in a string that is being matched by a regular expression, 'LIKE' sees it as the end of the string. This will be fixed in a future release. 
The issue is [here](https://github.com/NVIDIA/spark-rapids/issues/119). -### RLike - -The GPU implementation of RLike has a number of known issues where behavior is not consistent with Apache Spark and -this expression is disabled by default. It can be enabled setting `spark.rapids.sql.expression.RLike=true`. - -A summary of known issues is shown below but this is not intended to be a comprehensive list. We recommend that you -do your own testing to verify whether the GPU implementation of `RLike` is suitable for your use case. - -We plan on improving the RLike functionality over time to make it more compatible with Spark so this feature should -be used at your own risk with the expectation that the behavior will change in future releases. - - -#### Null support - -The GPU implementation of RLike supports null characters in the input but does not support null characters in -the regular expression and will fall back to the CPU in this case. - -#### Qualifiers with nothing to repeat - -Spark supports qualifiers in cases where there is nothing to repeat. For example, Spark supports `a*+` and this -will match all inputs. The GPU implementation of RLike does not support this syntax and will throw an exception with -the message `nothing to repeat at position 0`. - -#### Stricter escaping requirements +## Regular Expressions -The GPU implementation of RLike has stricter requirements around escaping special characters in some cases. +### regexp_replace -| Pattern | Input | Spark on CPU | Spark on GPU | -|-----------|--------|--------------|--------------| -| `a[-+]` | `a-` | Match | No Match | -| `a[\-\+]` | `a-` | Match | Match | +The RAPIDS Accelerator for Apache Spark currently supports string literal matches, not wildcard +matches for the `regexp_replace` function and will fall back to CPU if a regular expression pattern +is provided. -#### Empty groups +### RLike -The GPU implementation of RLike does not support empty groups correctly. 
+The GPU implementation of RLike has the following known issues where behavior is not consistent with Apache Spark and +this expression is disabled by default. It can be enabled setting `spark.rapids.sql.expression.RLike=true`. -| Pattern | Input | Spark on CPU | Spark on GPU | -|-----------|--------|--------------|--------------| -| `z()?` | `a` | No Match | Match | -| `z()*` | `a` | No Match | Match | +- `.` matches `\r` on the GPU but not on the CPU ([cuDF issue #9619](https://github.com/rapidsai/cudf/issues/9619)) +- `$` does not match the end of string if the string ends with a line-terminator + ([cuDF issue #9620](https://github.com/rapidsai/cudf/issues/9620)) +`RLike` will fall back to CPU if any regular expressions are detected that are not supported on the GPU +or would produce different results on the GPU. ## Timestamps From 62ea0e0647ae7f7d01abedbd542e09bd21cffbd2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 5 Nov 2021 14:58:02 -0600 Subject: [PATCH 03/14] code cleanup and documentation --- .../spark/sql/rapids/stringFunctions.scala | 135 ++++++++++-------- .../RegularExpressionTranspilerSuite.scala | 5 +- 2 files changed, 78 insertions(+), 62 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 456544e41d1..45c794b7d76 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -772,9 +772,20 @@ class GpuRLikeMeta( } /** - * Regular expression parser. + * Regular expression parser based on a Pratt Parser design. * - * Suggested reading before making changes to this code: + * The goal of this parser is to build a minimal AST that allows us + * to validate that we can support the expression on the GPU. 
The goal + * is not to parse with the level of detail that would be required if + * we were building an evaluation engine. For example, operator precedence is + * largely ignored but could be added if we need it later. + * + * The Java and cuDF regular expression documentation has been used as a reference: + * + * Java regex: https://docs.oracle.com/javase/7/docs/api/java/util/regex/Pattern.html + * cuDF regex: https://docs.rapids.ai/api/libcudf/stable/md_regex.html + * + * The following blog posts provide some background on Pratt Parsers and parsing regex. * * - https://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/ * - https://matt.might.net/articles/parsing-regex-with-recursive-descent/ @@ -831,7 +842,6 @@ class RegexParser(pattern: String) { } private def parseFactor(): RegexAST = { - // TODO rewrite this var base = parseBase() while (!eof() && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?') || isValidQuantifierAhead())) { @@ -868,10 +878,6 @@ class RegexParser(pattern: String) { RegexGroup(term) } - /** - * Parse a character class as defined in the cuDF documentation at - * https://docs.rapids.ai/api/libcudf/stable/md_regex.html - */ private def parseCharacterClass(): RegexCharacterClass = { val start = i val characterClass = RegexCharacterClass(negated = false, characters = ListBuffer()) @@ -889,9 +895,8 @@ class RegexParser(pattern: String) { // Negates the character class, causing it to match a single character not listed in // the character class. 
Only valid immediately after the opening '[' characterClass.negated = true - case '\n' | '\r' | '\t' | '\b' | '\f' => - // TODO add \a here as well - // Add this special character to the character class + case '\n' | '\r' | '\t' | '\b' | '\f' | '\007' => + // treat as a literal character and add to the character class characterClass.append(ch) case '\\' => peek() match { @@ -914,9 +919,9 @@ class RegexParser(pattern: String) { peek() match { case Some('-') => consumeExpected('-') - // TODO look for other non valid chars like escape peek() match { case Some(']') => + // '-' at end of class e.g. "[abc-]" characterClass.append(ch) characterClass.append('-') case Some(end) => @@ -945,21 +950,23 @@ class RegexParser(pattern: String) { * {n,m} (only valid if m >= n) */ private def parseQuantifierOrLiteralBrace(): RegexAST = { - //TODO refactor to avoid code duplication + // assumes that '{' has already been consumed val start = i + + def treatAsLiteralBrace() = { + // this was not a quantifier, just a literal '{' + i = start + 1 + RegexChar('{') + } + consumeInt match { - case None => - // this was not a quantifier, just a literal '{' - i = start + 1 - RegexChar('{') case Some(minLength) => peek() match { case Some(',') => consumeExpected(',') val max = consumeInt() if (peek().contains('}')) { - // end of quantifier consumeExpected('}') max match { case None => @@ -968,25 +975,20 @@ class RegexParser(pattern: String) { if (m >= minLength) { QuantifierVariableLength(minLength, max) } else { - // this was not a quantifier, just a literal '{' - i = start + 1 - RegexChar('{') + treatAsLiteralBrace() } } } else { - // this was not a quantifier, just a literal '{' - i = start + 1 - RegexChar('{') + treatAsLiteralBrace() } case Some('}') => - // end of quantifier consumeExpected('}') QuantifierFixedLength(minLength) case _ => - // this was not a quantifier, just a literal '{' - i = start + 1 - RegexChar('{') + treatAsLiteralBrace() } + case None => + treatAsLiteralBrace() } } 
@@ -998,7 +1000,7 @@ class RegexParser(pattern: String) { consumeExpected(ch) ch match { case 'A' | 'Z' => - // anchors BOL / EOL + // BOL / EOL anchors RegexEscaped(ch) case 's' | 'S' | 'd' | 'D' | 'w' | 'W' => // meta sequences @@ -1014,7 +1016,6 @@ class RegexParser(pattern: String) { case _ if ch.isDigit => parseOctalDigit case other => - //TODO handle this in transpiler not in parser throw new RegexUnsupportedException( s"invalid or unsupported escape character '$other'", Some(i - 1)) } @@ -1113,90 +1114,108 @@ class RegexParser(pattern: String) { * if this is not possible. */ class CudfRegexTranspiler { + + val nothingToRepeat = "nothing to repeat" def transpile(pattern: String): String = { // parse the source regular expression val regex = new RegexParser(pattern).parse() - // validatge that the regex is supported by cuDF + // validate that the regex is supported by cuDF validate(regex) // write out to regex string, performing minor transformations + // such as adding additional escaping regex.toRegexString } private def validate(regex: RegexAST): Unit = { regex match { - case RegexGroup(RegexSequence(parts)) => - if (parts.isEmpty) { - throw new RegexUnsupportedException("cuDF does not support empty groups") - } - val s = parts.map(_.toRegexString).mkString - if (s.startsWith("?") && !s.startsWith("?:")) { - throw new RegexUnsupportedException("nothing to repeat") - } + case RegexGroup(RegexSequence(parts)) if parts.isEmpty => + // example: "()" + throw new RegexUnsupportedException("cuDF does not support empty groups") case RegexRepetition(RegexEscaped(_), _) => - throw new RegexUnsupportedException("nothing to repeat") - case RegexRepetition(RegexChar(a), _) if "$^.".contains(a) => - throw new RegexUnsupportedException("nothing to repeat") + // example: "\B?" 
+ throw new RegexUnsupportedException(nothingToRepeat) + case RegexRepetition(RegexChar(a), _) if "$^".contains(a) => + // example: "$*" + throw new RegexUnsupportedException(nothingToRepeat) case RegexRepetition(RegexRepetition(_, _), _) => - // cuDF generally does not support nested repetitions such as possessive quantifiers ("a*+") - throw new RegexUnsupportedException("nothing to repeat") + // example: "a*+" + throw new RegexUnsupportedException(nothingToRepeat) case RegexChoice(l, r) => (l, r) match { // check for empty left-hand side caused by ^ or $ or a repetition case (RegexSequence(a), _) => a.lastOption match { case None => - throw new RegexUnsupportedException("nothing to repeat") + // example: "|a" + throw new RegexUnsupportedException(nothingToRepeat) case Some(RegexChar(ch)) if ch == '$' || ch == '^' => - throw new RegexUnsupportedException("nothing to repeat") + // example: "^|a" + throw new RegexUnsupportedException(nothingToRepeat) case Some(RegexRepetition(_, _)) => - throw new RegexUnsupportedException("nothing to repeat") + // example: "a*|a" + throw new RegexUnsupportedException(nothingToRepeat) case _ => } // check for empty right-hand side caused by ^ or $ case (_, RegexSequence(b)) => b.headOption match { case None => - throw new RegexUnsupportedException("nothing to repeat") + // example: "|b" + throw new RegexUnsupportedException(nothingToRepeat) case Some(RegexChar(ch)) if ch == '$' || ch == '^' => - throw new RegexUnsupportedException("nothing to repeat") + // example: "a|$" + throw new RegexUnsupportedException(nothingToRepeat) case _ => } case (RegexRepetition(_, _), _) => - throw new RegexUnsupportedException("nothing to repeat") + // example: "a*|a" + throw new RegexUnsupportedException(nothingToRepeat) case _ => } - //TODO do not use toRegexString here case RegexSequence(parts) => - if (parts.head.toRegexString == "|" || parts.last.toRegexString == "|") { - throw new RegexUnsupportedException("nothing to repeat") + if 
(isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) { + // examples: "a|", "|b" + throw new RegexUnsupportedException(nothingToRepeat) } - //TODO this is too hacky - if (parts.head.toRegexString.startsWith("{")) { + if (isRegexChar(parts.head, '{')) { + // example: "{" // cuDF would treat this as a quantifier even though in this // context (being at the start of a sequence) it is not quantifying anything - throw new RegexUnsupportedException("nothing to repeat") + // note that we could choose to escape this in the transpiler rather than + // falling back to CPU + throw new RegexUnsupportedException(nothingToRepeat) } parts.foreach { case RegexEscaped(ch) if ch == 'b' || ch == 'B' => + // example: "a\Bb" // this needs further analysis to determine why words boundaries behave - // differently betwee Java and cuDF + // differently between Java and cuDF throw new RegexUnsupportedException("word boundaries are not supported") case _ => } - case RegexCharacterClass(negated, characters) => + case RegexCharacterClass(_, characters) => characters.foreach { case RegexChar(ch) if ch == '[' || ch == ']' => - // this can have very different semantics between Java and cuDF + // examples: + // - "[a[]" should match the literal characters "a" and "[" + // - "[a-b[c-d]]" is supported by Java but not cuDF throw new RegexUnsupportedException("nested character classes are not supported") case _ => } case _ => } + + // walk down the tree and validate children regex.children().foreach(validate) } + + private def isRegexChar(expr: RegexAST, value: Char): Boolean = expr match { + case RegexChar(ch) => ch == value + case _ => false + } } sealed trait RegexAST { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 1d5723cccd5..6fa2ec37d20 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ 
b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -73,10 +73,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support OR at BOL / EOL") { val patterns = Seq("$|a", "^|a") - // "a|$" ? - // "a|^" ? patterns.foreach(pattern => { - println(pattern) assertUnsupported(pattern, "nothing to repeat") }) } @@ -88,7 +85,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("nothing to repeat") { - val patterns = Seq("$*", "^+", ".*") + val patterns = Seq("$*", "^+") patterns.foreach(pattern => assertUnsupported(pattern, "nothing to repeat")) } From b4e4fd4ca43cc2fee0a2ad90eee346606f9726e0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 5 Nov 2021 15:05:09 -0600 Subject: [PATCH 04/14] remove adhoc test --- .../nvidia/spark/rapids/RegularExpressionParserSuite.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index a2a93e2eff5..c255ca1687f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -43,10 +43,6 @@ class RegularExpressionParserSuite extends FunSuite { SimpleQuantifier('+'))))) } - test("adhoc") { - println(parse("ab*$}\\B? 
")) - } - private def parse(pattern: String): RegexAST = { new RegexParser(pattern).parse() } From a00215c9d62c5f0ebcadfe96da5a30141e37847e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 5 Nov 2021 15:24:52 -0600 Subject: [PATCH 05/14] more parser tests --- .../rapids/RegularExpressionParserSuite.scala | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index c255ca1687f..916b5bcc14a 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -19,7 +19,7 @@ import scala.collection.mutable.ListBuffer import org.scalatest.FunSuite -import org.apache.spark.sql.rapids.{QuantifierFixedLength, RegexAST, RegexChar, RegexParser, RegexRepetition, RegexSequence, SimpleQuantifier} +import org.apache.spark.sql.rapids.{QuantifierFixedLength, RegexAST, RegexChar, RegexCharacterClass, RegexCharacterRange, RegexChoice, RegexGroup, RegexParser, RegexRepetition, RegexSequence, SimpleQuantifier} class RegularExpressionParserSuite extends FunSuite { @@ -40,7 +40,30 @@ class RegularExpressionParserSuite extends FunSuite { RegexSequence(ListBuffer( RegexRepetition( RegexRepetition(RegexChar('a'), SimpleQuantifier('*')), - SimpleQuantifier('+'))))) + SimpleQuantifier('+'))))) + } + + test("choice") { + assert(parse("a|b") === + RegexChoice(RegexSequence(ListBuffer(RegexChar('a'))), + RegexSequence(ListBuffer(RegexChar('b'))))) + } + + test("group") { + assert(parse("(a)(b)") === + RegexSequence(ListBuffer( + RegexGroup(RegexSequence(ListBuffer(RegexChar('a')))), + RegexGroup(RegexSequence(ListBuffer(RegexChar('b'))))))) + } + + test("character class") { + assert(parse("[a-z+A-Z]") === + RegexSequence(ListBuffer( + RegexCharacterClass(negated = false, + ListBuffer( + 
RegexCharacterRange('a', 'z'), + RegexChar('+'), + RegexCharacterRange('A', 'Z')))))) } private def parse(pattern: String): RegexAST = { From 5d0d8d7b2ef9bac61ddf38faa212bf4b7b6ab654 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 5 Nov 2021 15:51:12 -0600 Subject: [PATCH 06/14] revert accidental docs change --- docs/compatibility.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index 95591c185d3..fbd286d76fb 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -259,7 +259,7 @@ case. ## LIKE -If a null char '\0' is in a string that is being matched by a regular expression, 'LIKE' sees it as +If a null char '\0' is in a string that is being matched by a regular expression, `LIKE` sees it as the end of the string. This will be fixed in a future release. The issue is [here](https://github.com/NVIDIA/spark-rapids/issues/119). From 3f8d2ae09bb16be2eb5b418217909d6d88192546 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 8 Nov 2021 17:15:38 -0700 Subject: [PATCH 07/14] Move regular expression parser to new source file Signed-off-by: Andy Grove --- .../com/nvidia/spark/rapids/RegexParser.scala | 609 ++++++++++++++++++ .../spark/sql/rapids/stringFunctions.scala | 593 +---------------- .../rapids/RegularExpressionParserSuite.scala | 1 - .../RegularExpressionTranspilerSuite.scala | 3 - 4 files changed, 610 insertions(+), 596 deletions(-) create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala new file mode 100644 index 00000000000..bbc21a3940c --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -0,0 +1,609 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import java.sql.SQLException + +import scala.collection.mutable.ListBuffer + +/** + * Regular expression parser based on a Pratt Parser design. + * + * The goal of this parser is to build a minimal AST that allows us + * to validate that we can support the expression on the GPU. The goal + * is not to parse with the level of detail that would be required if + * we were building an evaluation engine. For example, operator precedence is + * largely ignored but could be added if we need it later. + * + * The Java and cuDF regular expression documentation has been used as a reference: + * + * Java regex: https://docs.oracle.com/javase/7/docs/api/java/util/regex/Pattern.html + * cuDF regex: https://docs.rapids.ai/api/libcudf/stable/md_regex.html + * + * The following blog posts provide some background on Pratt Parsers and parsing regex. 
+ *
+ * - https://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/
+ * - https://matt.might.net/articles/parsing-regex-with-recursive-descent/
+ */
+class RegexParser(pattern: String) {
+
+  /** index of current position within the string being parsed */
+  private var i = 0
+
+  /**
+   * Parse the whole pattern.
+   *
+   * @return the root of the AST
+   * @throws RegexUnsupportedException if the pattern cannot be parsed completely
+   */
+  def parse(): RegexAST = {
+    val ast = parseInternal()
+    if (!eof()) {
+      throw new RegexUnsupportedException("failed to parse full regex")
+    }
+    ast
+  }
+
+  /** Parse a term, followed by an optional '|' and the right-hand side of the choice */
+  private def parseInternal(): RegexAST = {
+    val term = parseTerm(() => peek().contains('|'))
+    if (!eof() && peek().contains('|')) {
+      consumeExpected('|')
+      RegexChoice(term, parseInternal())
+    } else {
+      term
+    }
+  }
+
+  /** Parse factors into a single sequence until EOF or until `until` returns true */
+  private def parseTerm(until: () => Boolean): RegexAST = {
+    val sequence = RegexSequence(new ListBuffer())
+    while (!eof() && !until()) {
+      parseFactor() match {
+        case RegexSequence(parts) =>
+          sequence.parts ++= parts
+        case other =>
+          sequence.parts += other
+      }
+    }
+    sequence
+  }
+
+  /** Look ahead, without moving the index, for a well-formed `{n}`, `{n,}` or `{n,m}` */
+  private def isValidQuantifierAhead(): Boolean = {
+    if (peek().contains('{')) {
+      val bookmark = i
+      consumeExpected('{')
+      val q = parseQuantifierOrLiteralBrace()
+      i = bookmark
+      q match {
+        case _: QuantifierFixedLength | _: QuantifierVariableLength => true
+        case _ => false
+      }
+    } else {
+      false
+    }
+  }
+
+  /** Parse a base expression followed by any number of quantifiers */
+  private def parseFactor(): RegexAST = {
+    var base = parseBase()
+    while (!eof() && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?')
+        || isValidQuantifierAhead())) {
+
+      if (peek().contains('{')) {
+        consumeExpected('{')
+        // isValidQuantifierAhead has already confirmed that this is a quantifier
+        base = RegexRepetition(base, parseQuantifierOrLiteralBrace().asInstanceOf[RegexQuantifier])
+      } else {
+        base = RegexRepetition(base, SimpleQuantifier(consume()))
+      }
+    }
+    base
+  }
+
+  private def parseBase(): RegexAST = {
+    consume() match {
+      case '(' =>
+        parseGroup()
+      case '[' =>
+        parseCharacterClass()
+      case '\\' =>
+        parseEscapedCharacter()
+      case '\u0000' =>
+        throw new RegexUnsupportedException(
+          "cuDF does not support null characters in regular expressions", Some(i))
+      case other =>
+        RegexChar(other)
+    }
+  }
+
+  /** Parse the contents of a group; the opening '(' has already been consumed */
+  private def parseGroup(): RegexAST = {
+    val term = parseTerm(() => peek().contains(')'))
+    consumeExpected(')')
+    RegexGroup(term)
+  }
+
+  /**
+   * Parse a character class; the opening '[' has already been consumed.
+   *
+   * @throws RegexUnsupportedException on EOF before the closing ']', on a null
+   *                                   character, or on an escape sequence that is
+   *                                   not handled inside a character class
+   */
+  private def parseCharacterClass(): RegexCharacterClass = {
+    val start = i
+    val characterClass = RegexCharacterClass(negated = false, characters = ListBuffer())
+    // loop until the end of the character class or EOF
+    var characterClassComplete = false
+    while (!eof() && !characterClassComplete) {
+      val ch = consume()
+      ch match {
+        case '[' =>
+          // treat as a literal character and add to the character class
+          characterClass.append(ch)
+        case ']' =>
+          characterClassComplete = true
+        case '^' if i == start + 1 =>
+          // Negates the character class, causing it to match a single character not listed in
+          // the character class. Only valid immediately after the opening '['
+          characterClass.negated = true
+        case '\n' | '\r' | '\t' | '\b' | '\f' | '\007' =>
+          // treat as a literal character and add to the character class
+          characterClass.append(ch)
+        case '\\' =>
+          peek() match {
+            case None =>
+              throw new RegexUnsupportedException(
+                s"unexpected EOF while parsing escaped character", Some(i))
+            case Some(escaped) =>
+              escaped match {
+                case '\\' | '^' | '-' | ']' | '+' =>
+                  // escaped metacharacter within character class
+                  characterClass.appendEscaped(consumeExpected(escaped))
+                case other =>
+                  // without this case an escape such as "[\d]" fails with a
+                  // scala.MatchError instead of RegexUnsupportedException
+                  throw new RegexUnsupportedException(
+                    s"invalid or unsupported escape character '$other' in character class",
+                    Some(i))
+              }
+          }
+        case '\u0000' =>
+          throw new RegexUnsupportedException(
+            "cuDF does not support null characters in regular expressions", Some(i))
+        case _ =>
+          // check for range
+          val rangeStart = ch
+          peek() match {
+            case Some('-') =>
+              consumeExpected('-')
+              peek() match {
+                case Some(']') =>
+                  // '-' at end of class e.g. "[abc-]"
+                  characterClass.append(ch)
+                  characterClass.append('-')
+                case Some(end) =>
+                  skip()
+                  characterClass.appendRange(rangeStart, end)
+                case None =>
+                  // pattern ended directly after the '-', e.g. "[a-"
+                  throw new RegexUnsupportedException(
+                    s"unexpected EOF while parsing character class", Some(i))
+              }
+            case _ =>
+              // treat as supported literal character
+              characterClass.append(ch)
+          }
+      }
+    }
+    if (!characterClassComplete) {
+      throw new RegexUnsupportedException(
+        s"unexpected EOF while parsing character class", Some(i))
+    }
+    characterClass
+  }
+
+
+  /**
+   * Parse a quantifier in one of the following formats:
+   *
+   * {n}
+   * {n,}
+   * {n,m} (only valid if m >= n)
+   */
+  private def parseQuantifierOrLiteralBrace(): RegexAST = {
+
+    // assumes that '{' has already been consumed
+    val start = i
+
+    def treatAsLiteralBrace() = {
+      // this was not a quantifier, just a literal '{'
+      i = start + 1
+      RegexChar('{')
+    }
+
+    consumeInt match {
+      case Some(minLength) =>
+        peek() match {
+          case Some(',') =>
+            consumeExpected(',')
+            val max = consumeInt()
+            if (peek().contains('}')) {
+              consumeExpected('}')
+              max match {
+                case None =>
+                  QuantifierVariableLength(minLength, None)
+                case Some(m) =>
+                  if (m >= minLength) {
+                    QuantifierVariableLength(minLength, max)
+                  } else {
+                    treatAsLiteralBrace()
+                  }
+              }
+            } else {
+              treatAsLiteralBrace()
+            }
+          case Some('}') =>
+            consumeExpected('}')
+            QuantifierFixedLength(minLength)
+          case _ =>
+            treatAsLiteralBrace()
+        }
+      case None =>
+        treatAsLiteralBrace()
+    }
+  }
+
+  /**
+   * Parse an escape sequence; the leading backslash has already been consumed.
+   *
+   * @throws RegexUnsupportedException if the escape is not recognized
+   */
+  private def parseEscapedCharacter(): RegexAST = {
+    peek() match {
+      case None =>
+        throw new RegexUnsupportedException("escape at end of string", Some(i))
+      case Some(ch) =>
+        consumeExpected(ch)
+        ch match {
+          case 'A' | 'Z' =>
+            // BOL / EOL anchors
+            RegexEscaped(ch)
+          case 's' | 'S' | 'd' | 'D' | 'w' | 'W' =>
+            // meta sequences
+            RegexEscaped(ch)
+          case 'B' | 'b' =>
+            // word boundaries
+            RegexEscaped(ch)
+          case '[' | '\\' | '^' | '$' | '.' | '|' | '?' | '*' | '+' | '(' | ')' | '{' | '}' =>
+            // escaped metacharacter (the pipe must be the ASCII '|', not U+23AE)
+            RegexEscaped(ch)
+          case 'x' =>
+            parseHexDigit
+          case '0' =>
+            // Java octal escapes always start with a zero
+            parseOctalDigit
+          case _ if ch.isDigit =>
+            // in Java \1 to \9 are back references rather than octal escapes
+            throw new RegexUnsupportedException(
+              "back references are not supported", Some(i - 1))
+          case other =>
+            throw new RegexUnsupportedException(
+              s"invalid or unsupported escape character '$other'", Some(i - 1))
+        }
+    }
+  }
+
+  /** Parse the two hex digits of a `\xhh` escape; `\x` has already been consumed */
+  private def parseHexDigit: RegexAST = {
+
+    def isHexDigit(ch: Char): Boolean = ch.isDigit ||
+      (ch >= 'a' && ch <= 'f') ||
+      (ch >= 'A' && ch <= 'F')
+
+    if (i + 1 < pattern.length
+        && isHexDigit(pattern.charAt(i))
+        && isHexDigit(pattern.charAt(i+1))) {
+      val hex = pattern.substring(i, i + 2)
+      i += 2
+      RegexHexChar(hex)
+    } else {
+      throw new RegexUnsupportedException(
+        "Invalid hex digit", Some(i))
+    }
+  }
+
+  /**
+   * Parse the `mnn` digits of a `\0mnn` octal escape; `\0` has already been consumed.
+   *
+   * Only the three-digit form is handled. The first digit must be in the range
+   * 0 to 3 and the remaining digits in the range 0 to 7; anything else is rejected.
+   */
+  private def parseOctalDigit = {
+
+    def isOctalDigit(ch: Char): Boolean = ch >= '0' && ch <= '7'
+
+    if (i + 2 < pattern.length
+        && pattern.charAt(i) >= '0' && pattern.charAt(i) <= '3'
+        && isOctalDigit(pattern.charAt(i+1))
+        && isOctalDigit(pattern.charAt(i+2))) {
+      // capture all three digits that were validated and are skipped below
+      val octal = pattern.substring(i, i + 3)
+      i += 3
+      RegexOctalChar(octal)
+    } else {
+      throw new RegexUnsupportedException(
+        "Invalid octal digit", Some(i))
+    }
+  }
+
+  /** Determine if we are at the end of the input */
+  private def eof(): Boolean = i == pattern.length
+
+  /** Advance the index by one */
+  private def skip(): Unit = {
+    if (eof()) {
+      throw new RegexUnsupportedException("Unexpected EOF", Some(i))
+    }
+    i += 1
+  }
+
+  /** Get the next character and advance the index by one */
+  private def consume(): Char = {
+    if (eof()) {
+      throw new RegexUnsupportedException("Unexpected EOF", Some(i))
+    } else {
+      i += 1
+      pattern.charAt(i - 1)
+    }
+  }
+
+  /** Consume the next character if it is the one we expect */
+  private def consumeExpected(expected: Char): Char = {
+    val consumed = consume()
+    if (consumed != expected) {
+      throw new RegexUnsupportedException(
+        s"Expected '$expected' but found '$consumed'", Some(i-1))
+    }
+    consumed
+  }
+
+  /** Peek at the next character without consuming it */
+  private def peek(): Option[Char] = {
+    if (eof()) {
+      None
+    } else {
+      Some(pattern.charAt(i))
+    }
+  }
+
+  /**
+   * Consume a run of digits and return its integer value, or None if the next
+   * character is not a digit.
+   */
+  private def consumeInt(): Option[Int] = {
+    val start = i
+    while (!eof() && peek().exists(_.isDigit)) {
+      skip()
+    }
+    if (start == i) {
+      None
+    } else {
+      try {
+        Some(pattern.substring(start, i).toInt)
+      } catch {
+        case _: NumberFormatException =>
+          // digits that do not fit in an Int, e.g. "a{99999999999}"
+          throw new RegexUnsupportedException(
+            "quantifier value is out of range", Some(start))
+      }
+    }
+  }
+
+}
+ +/** + * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception + * if this is not possible. + */ +class CudfRegexTranspiler { + + val nothingToRepeat = "nothing to repeat" + + def transpile(pattern: String): String = { + // parse the source regular expression + val regex = new RegexParser(pattern).parse() + // validate that the regex is supported by cuDF + validate(regex) + // write out to regex string, performing minor transformations + // such as adding additional escaping + regex.toRegexString + } + + private def validate(regex: RegexAST): Unit = { + regex match { + case RegexGroup(RegexSequence(parts)) if parts.isEmpty => + // example: "()" + throw new RegexUnsupportedException("cuDF does not support empty groups") + case RegexRepetition(RegexEscaped(_), _) => + // example: "\B?" 
+ throw new RegexUnsupportedException(nothingToRepeat) + case RegexRepetition(RegexChar(a), _) if "$^".contains(a) => + // example: "$*" + throw new RegexUnsupportedException(nothingToRepeat) + case RegexRepetition(RegexRepetition(_, _), _) => + // example: "a*+" + throw new RegexUnsupportedException(nothingToRepeat) + case RegexChoice(l, r) => + (l, r) match { + // check for empty left-hand side caused by ^ or $ or a repetition + case (RegexSequence(a), _) => + a.lastOption match { + case None => + // example: "|a" + throw new RegexUnsupportedException(nothingToRepeat) + case Some(RegexChar(ch)) if ch == '$' || ch == '^' => + // example: "^|a" + throw new RegexUnsupportedException(nothingToRepeat) + case Some(RegexRepetition(_, _)) => + // example: "a*|a" + throw new RegexUnsupportedException(nothingToRepeat) + case _ => + } + // check for empty right-hand side caused by ^ or $ + case (_, RegexSequence(b)) => + b.headOption match { + case None => + // example: "|b" + throw new RegexUnsupportedException(nothingToRepeat) + case Some(RegexChar(ch)) if ch == '$' || ch == '^' => + // example: "a|$" + throw new RegexUnsupportedException(nothingToRepeat) + case _ => + } + case (RegexRepetition(_, _), _) => + // example: "a*|a" + throw new RegexUnsupportedException(nothingToRepeat) + case _ => + } + + case RegexSequence(parts) => + if (isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) { + // examples: "a|", "|b" + throw new RegexUnsupportedException(nothingToRepeat) + } + if (isRegexChar(parts.head, '{')) { + // example: "{" + // cuDF would treat this as a quantifier even though in this + // context (being at the start of a sequence) it is not quantifying anything + // note that we could choose to escape this in the transpiler rather than + // falling back to CPU + throw new RegexUnsupportedException(nothingToRepeat) + } + parts.foreach { + case RegexEscaped(ch) if ch == 'b' || ch == 'B' => + // example: "a\Bb" + // this needs further analysis to determine why 
words boundaries behave + // differently between Java and cuDF + throw new RegexUnsupportedException("word boundaries are not supported") + case _ => + } + case RegexCharacterClass(_, characters) => + characters.foreach { + case RegexChar(ch) if ch == '[' || ch == ']' => + // examples: + // - "[a[]" should match the literal characters "a" and "[" + // - "[a-b[c-d]]" is supported by Java but not cuDF + throw new RegexUnsupportedException("nested character classes are not supported") + case _ => + } + + case _ => + } + + // walk down the tree and validate children + regex.children().foreach(validate) + } + + private def isRegexChar(expr: RegexAST, value: Char): Boolean = expr match { + case RegexChar(ch) => ch == value + case _ => false + } +} + +sealed trait RegexAST { + def children(): Seq[RegexAST] + def toRegexString: String +} + +sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST { + override def children(): Seq[RegexAST] = parts + override def toRegexString: String = parts.map(_.toRegexString).mkString +} + +sealed case class RegexGroup(term: RegexAST) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(term) + override def toRegexString: String = s"(${term.toRegexString})" +} + +sealed case class RegexChoice(a: RegexAST, b: RegexAST) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(a, b) + override def toRegexString: String = s"${a.toRegexString}|${b.toRegexString}" +} + +sealed case class RegexRepetition(a: RegexAST, quantifier: RegexQuantifier) extends RegexAST { + override def children(): Seq[RegexAST] = Seq(a) + override def toRegexString: String = s"${a.toRegexString}${quantifier.toRegexString}" +} + +sealed trait RegexQuantifier extends RegexAST + +sealed case class SimpleQuantifier(ch: Char) extends RegexQuantifier { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = ch.toString +} + +sealed case class QuantifierFixedLength(length: Int) + extends 
RegexQuantifier { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = { + s"{$length}" + } +} + +sealed case class QuantifierVariableLength(minLength: Int, maxLength: Option[Int]) + extends RegexQuantifier{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = { + maxLength match { + case Some(max) => + s"{$minLength,$max}" + case _ => + s"{$minLength,}" + } + } +} + +sealed trait RegexCharacterClassComponent extends RegexAST + +sealed case class RegexHexChar(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\x$a" +} + +sealed case class RegexOctalChar(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\$a" +} + +sealed case class RegexChar(a: Char) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"$a" +} + +sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\$a" +} + +sealed case class RegexCharacterRange(start: Char, end: Char) + extends RegexCharacterClassComponent{ + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"$start-$end" +} + +sealed case class RegexCharacterClass( + var negated: Boolean, + var characters: ListBuffer[RegexCharacterClassComponent]) + extends RegexAST { + + override def children(): Seq[RegexAST] = characters + def append(ch: Char): Unit = { + characters += RegexChar(ch) + } + + def appendEscaped(ch: Char): Unit = { + characters += RegexEscaped(ch) + } + + def appendRange(start: Char, end: Char): Unit = { + characters += RegexCharacterRange(start, end) + } + + override def toRegexString: String = { + val builder = new 
StringBuilder("[") + if (negated) { + builder.append("^") + } + for (a <- characters) { + a match { + case RegexChar(ch) if requiresEscaping(ch) => + // cuDF has stricter escaping requirements for certain characters + // within a character class compared to Java or Python regex + builder.append(s"\\$ch") + case other => + builder.append(other.toRegexString) + } + } + builder.append("]") + builder.toString() + } + + private def requiresEscaping(ch: Char): Boolean = { + // there are likely other cases that we will need to add here but this + // covers everything we have seen so far during fuzzing + ch match { + case '-' => + // cuDF requires '-' to be escaped when used as a character within a character + // to disambiguate from the character range syntax 'a-b' + true + case _ => + false + } + } +} + +class RegexUnsupportedException(message: String, index: Option[Int] = None) + extends SQLException { + override def getMessage: String = { + index match { + case Some(i) => s"$message at index $index" + case _ => message + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 45c794b7d76..bc413b48e58 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -16,9 +16,7 @@ package org.apache.spark.sql.rapids -import java.sql.SQLException - -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ColumnView, DType, PadSide, Scalar, Table} import com.nvidia.spark.rapids._ @@ -771,595 +769,6 @@ class GpuRLikeMeta( GpuRLike(lhs, rhs) } -/** - * Regular expression parser based on a Pratt Parser design. 
- * - * The goal of this parser is to build a minimal AST that allows us - * to validate that we can support the expression on the GPU. The goal - * is not to parse with the level of detail that would be required if - * we were building an evaluation engine. For example, operator precedence is - * largely ignored but could be added if we need it later. - * - * The Java and cuDF regular expression documentation has been used as a reference: - * - * Java regex: https://docs.oracle.com/javase/7/docs/api/java/util/regex/Pattern.html - * cuDF regex: https://docs.rapids.ai/api/libcudf/stable/md_regex.html - * - * The following blog posts provide some background on Pratt Parsers and parsing regex. - * - * - https://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/ - * - https://matt.might.net/articles/parsing-regex-with-recursive-descent/ - */ -class RegexParser(pattern: String) { - - /** index of current position within the string being parsed */ - private var i = 0 - - def parse(): RegexAST = { - val ast = parseInternal() - if (!eof()) { - throw new RegexUnsupportedException("failed to parse full regex") - } - ast - } - - private def parseInternal(): RegexAST = { - val term = parseTerm(() => peek().contains('|')) - if (!eof() && peek().contains('|')) { - consumeExpected('|') - RegexChoice(term, parseInternal()) - } else { - term - } - } - - private def parseTerm(until: () => Boolean): RegexAST = { - val sequence = RegexSequence(new ListBuffer()) - while (!eof() && !until()) { - parseFactor() match { - case RegexSequence(parts) => - sequence.parts ++= parts - case other => - sequence.parts += other - } - } - sequence - } - - private def isValidQuantifierAhead(): Boolean = { - if (peek().contains('{')) { - val bookmark = i - consumeExpected('{') - val q = parseQuantifierOrLiteralBrace() - i = bookmark - q match { - case _: QuantifierFixedLength | _: QuantifierVariableLength => true - case _ => false - } - } else { - false - } - } - - 
private def parseFactor(): RegexAST = { - var base = parseBase() - while (!eof() && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?') - || isValidQuantifierAhead())) { - - if (peek().contains('{')) { - consumeExpected('{') - base = RegexRepetition(base, parseQuantifierOrLiteralBrace().asInstanceOf[RegexQuantifier]) - } else { - base = RegexRepetition(base, SimpleQuantifier(consume())) - } - } - base - } - - private def parseBase(): RegexAST = { - consume() match { - case '(' => - parseGroup() - case '[' => - parseCharacterClass() - case '\\' => - parseEscapedCharacter() - case '\u0000' => - throw new RegexUnsupportedException( - "cuDF does not support null characters in regular expressions", Some(i)) - case other => - RegexChar(other) - } - } - - private def parseGroup(): RegexAST = { - val term = parseTerm(() => peek().contains(')')) - consumeExpected(')') - RegexGroup(term) - } - - private def parseCharacterClass(): RegexCharacterClass = { - val start = i - val characterClass = RegexCharacterClass(negated = false, characters = ListBuffer()) - // loop until the end of the character class or EOF - var characterClassComplete = false - while (!eof() && !characterClassComplete) { - val ch = consume() - ch match { - case '[' => - // treat as a literal character and add to the character class - characterClass.append(ch) - case ']' => - characterClassComplete = true - case '^' if i == start + 1 => - // Negates the character class, causing it to match a single character not listed in - // the character class. 
Only valid immediately after the opening '[' - characterClass.negated = true - case '\n' | '\r' | '\t' | '\b' | '\f' | '\007' => - // treat as a literal character and add to the character class - characterClass.append(ch) - case '\\' => - peek() match { - case None => - throw new RegexUnsupportedException( - s"unexpected EOF while parsing escaped character", Some(i)) - case Some(ch) => - ch match { - case '\\' | '^' | '-' | ']' | '+' => - // escaped metacharacter within character class - characterClass.appendEscaped(consumeExpected(ch)) - } - } - case '\u0000' => - throw new RegexUnsupportedException( - "cuDF does not support null characters in regular expressions", Some(i)) - case _ => - // check for range - val start = ch - peek() match { - case Some('-') => - consumeExpected('-') - peek() match { - case Some(']') => - // '-' at end of class e.g. "[abc-]" - characterClass.append(ch) - characterClass.append('-') - case Some(end) => - skip() - characterClass.appendRange(start, end) - } - case _ => - // treat as supported literal character - characterClass.append(ch) - } - } - } - if (!characterClassComplete) { - throw new RegexUnsupportedException( - s"unexpected EOF while parsing character class", Some(i)) - } - characterClass - } - - - /** - * Parse a quantifier in one of the following formats: - * - * {n} - * {n,} - * {n,m} (only valid if m >= n) - */ - private def parseQuantifierOrLiteralBrace(): RegexAST = { - - // assumes that '{' has already been consumed - val start = i - - def treatAsLiteralBrace() = { - // this was not a quantifier, just a literal '{' - i = start + 1 - RegexChar('{') - } - - consumeInt match { - case Some(minLength) => - peek() match { - case Some(',') => - consumeExpected(',') - val max = consumeInt() - if (peek().contains('}')) { - consumeExpected('}') - max match { - case None => - QuantifierVariableLength(minLength, None) - case Some(m) => - if (m >= minLength) { - QuantifierVariableLength(minLength, max) - } else { - 
treatAsLiteralBrace() - } - } - } else { - treatAsLiteralBrace() - } - case Some('}') => - consumeExpected('}') - QuantifierFixedLength(minLength) - case _ => - treatAsLiteralBrace() - } - case None => - treatAsLiteralBrace() - } - } - - private def parseEscapedCharacter(): RegexAST = { - peek() match { - case None => - throw new RegexUnsupportedException("escape at end of string", Some(i)) - case Some(ch) => - consumeExpected(ch) - ch match { - case 'A' | 'Z' => - // BOL / EOL anchors - RegexEscaped(ch) - case 's' | 'S' | 'd' | 'D' | 'w' | 'W' => - // meta sequences - RegexEscaped(ch) - case 'B' | 'b' => - // word boundaries - RegexEscaped(ch) - case '[' | '\\' | '^' | '$' | '.' | '⎮' | '?' | '*' | '+' | '(' | ')' | '{' | '}' => - // escaped metacharacter - RegexEscaped(ch) - case 'x' => - parseHexDigit - case _ if ch.isDigit => - parseOctalDigit - case other => - throw new RegexUnsupportedException( - s"invalid or unsupported escape character '$other'", Some(i - 1)) - } - } - } - - private def parseHexDigit: RegexAST = { - - def isHexDigit(ch: Char): Boolean = ch.isDigit || - (ch >= 'a' && ch <= 'f') || - (ch >= 'A' && ch <= 'F') - - if (i + 1 < pattern.length - && isHexDigit(pattern.charAt(i)) - && isHexDigit(pattern.charAt(i+1))) { - val hex = pattern.substring(i, i + 2) - i += 2 - RegexHexChar(hex) - } else { - throw new RegexUnsupportedException( - "Invalid hex digit", Some(i)) - } - } - - private def parseOctalDigit = { - - if (i + 2 < pattern.length - && pattern.charAt(i).isDigit - && pattern.charAt(i+1).isDigit - && pattern.charAt(i+2).isDigit) { - val hex = pattern.substring(i, i + 2) - i += 3 - RegexOctalChar(hex) - } else { - throw new RegexUnsupportedException( - "Invalid octal digit", Some(i)) - } - } - - /** Determine if we are at the end of the input */ - private def eof(): Boolean = i == pattern.length - - /** Advance the index by one */ - private def skip(): Unit = { - if (eof()) { - throw new RegexUnsupportedException("Unexpected EOF", Some(i)) - 
} - i += 1 - } - - /** Get the next character and advance the index by one */ - private def consume(): Char = { - if (eof()) { - throw new RegexUnsupportedException("Unexpected EOF", Some(i)) - } else { - i += 1 - pattern.charAt(i - 1) - } - } - - /** Consume the next character if it is the one we expect */ - private def consumeExpected(expected: Char): Char = { - val consumed = consume() - if (consumed != expected) { - throw new RegexUnsupportedException( - s"Expected '$expected' but found '$consumed'", Some(i-1)) - } - consumed - } - - /** Peek at the next character without consuming it */ - private def peek(): Option[Char] = { - if (eof()) { - None - } else { - Some(pattern.charAt(i)) - } - } - - private def consumeInt(): Option[Int] = { - val start = i - while (!eof() && peek().exists(_.isDigit)) { - skip() - } - if (start == i) { - None - } else { - Some(pattern.substring(start, i).toInt) - } - } - -} - -/** - * Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception - * if this is not possible. - */ -class CudfRegexTranspiler { - - val nothingToRepeat = "nothing to repeat" - - def transpile(pattern: String): String = { - // parse the source regular expression - val regex = new RegexParser(pattern).parse() - // validate that the regex is supported by cuDF - validate(regex) - // write out to regex string, performing minor transformations - // such as adding additional escaping - regex.toRegexString - } - - private def validate(regex: RegexAST): Unit = { - regex match { - case RegexGroup(RegexSequence(parts)) if parts.isEmpty => - // example: "()" - throw new RegexUnsupportedException("cuDF does not support empty groups") - case RegexRepetition(RegexEscaped(_), _) => - // example: "\B?" 
- throw new RegexUnsupportedException(nothingToRepeat) - case RegexRepetition(RegexChar(a), _) if "$^".contains(a) => - // example: "$*" - throw new RegexUnsupportedException(nothingToRepeat) - case RegexRepetition(RegexRepetition(_, _), _) => - // example: "a*+" - throw new RegexUnsupportedException(nothingToRepeat) - case RegexChoice(l, r) => - (l, r) match { - // check for empty left-hand side caused by ^ or $ or a repetition - case (RegexSequence(a), _) => - a.lastOption match { - case None => - // example: "|a" - throw new RegexUnsupportedException(nothingToRepeat) - case Some(RegexChar(ch)) if ch == '$' || ch == '^' => - // example: "^|a" - throw new RegexUnsupportedException(nothingToRepeat) - case Some(RegexRepetition(_, _)) => - // example: "a*|a" - throw new RegexUnsupportedException(nothingToRepeat) - case _ => - } - // check for empty right-hand side caused by ^ or $ - case (_, RegexSequence(b)) => - b.headOption match { - case None => - // example: "|b" - throw new RegexUnsupportedException(nothingToRepeat) - case Some(RegexChar(ch)) if ch == '$' || ch == '^' => - // example: "a|$" - throw new RegexUnsupportedException(nothingToRepeat) - case _ => - } - case (RegexRepetition(_, _), _) => - // example: "a*|a" - throw new RegexUnsupportedException(nothingToRepeat) - case _ => - } - - case RegexSequence(parts) => - if (isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) { - // examples: "a|", "|b" - throw new RegexUnsupportedException(nothingToRepeat) - } - if (isRegexChar(parts.head, '{')) { - // example: "{" - // cuDF would treat this as a quantifier even though in this - // context (being at the start of a sequence) it is not quantifying anything - // note that we could choose to escape this in the transpiler rather than - // falling back to CPU - throw new RegexUnsupportedException(nothingToRepeat) - } - parts.foreach { - case RegexEscaped(ch) if ch == 'b' || ch == 'B' => - // example: "a\Bb" - // this needs further analysis to determine why 
words boundaries behave - // differently between Java and cuDF - throw new RegexUnsupportedException("word boundaries are not supported") - case _ => - } - case RegexCharacterClass(_, characters) => - characters.foreach { - case RegexChar(ch) if ch == '[' || ch == ']' => - // examples: - // - "[a[]" should match the literal characters "a" and "[" - // - "[a-b[c-d]]" is supported by Java but not cuDF - throw new RegexUnsupportedException("nested character classes are not supported") - case _ => - } - - case _ => - } - - // walk down the tree and validate children - regex.children().foreach(validate) - } - - private def isRegexChar(expr: RegexAST, value: Char): Boolean = expr match { - case RegexChar(ch) => ch == value - case _ => false - } -} - -sealed trait RegexAST { - def children(): Seq[RegexAST] - def toRegexString: String -} - -sealed case class RegexSequence(parts: ListBuffer[RegexAST]) extends RegexAST { - override def children(): Seq[RegexAST] = parts - override def toRegexString: String = parts.map(_.toRegexString).mkString -} - -sealed case class RegexGroup(term: RegexAST) extends RegexAST { - override def children(): Seq[RegexAST] = Seq(term) - override def toRegexString: String = s"(${term.toRegexString})" -} - -sealed case class RegexChoice(a: RegexAST, b: RegexAST) extends RegexAST { - override def children(): Seq[RegexAST] = Seq(a, b) - override def toRegexString: String = s"${a.toRegexString}|${b.toRegexString}" -} - -sealed case class RegexRepetition(a: RegexAST, quantifier: RegexQuantifier) extends RegexAST { - override def children(): Seq[RegexAST] = Seq(a) - override def toRegexString: String = s"${a.toRegexString}${quantifier.toRegexString}" -} - -sealed trait RegexQuantifier extends RegexAST - -sealed case class SimpleQuantifier(ch: Char) extends RegexQuantifier { - override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = ch.toString -} - -sealed case class QuantifierFixedLength(length: Int) - extends 
RegexQuantifier { - override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = { - s"{$length}" - } -} - -sealed case class QuantifierVariableLength(minLength: Int, maxLength: Option[Int]) - extends RegexQuantifier{ - override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = { - maxLength match { - case Some(max) => - s"{$minLength,$max}" - case _ => - s"{$minLength,}" - } - } -} - -sealed trait RegexCharacterClassComponent extends RegexAST - -sealed case class RegexHexChar(a: String) extends RegexCharacterClassComponent { - override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = s"\\x$a" -} - -sealed case class RegexOctalChar(a: String) extends RegexCharacterClassComponent { - override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = s"\\$a" -} - -sealed case class RegexChar(a: Char) extends RegexCharacterClassComponent { - override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = s"$a" -} - -sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{ - override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = s"\\$a" -} - -sealed case class RegexCharacterRange(start: Char, end: Char) - extends RegexCharacterClassComponent{ - override def children(): Seq[RegexAST] = Seq.empty - override def toRegexString: String = s"$start-$end" -} - -sealed case class RegexCharacterClass( - var negated: Boolean, - var characters: ListBuffer[RegexCharacterClassComponent]) - extends RegexAST { - - override def children(): Seq[RegexAST] = characters - def append(ch: Char): Unit = { - characters += RegexChar(ch) - } - - def appendEscaped(ch: Char): Unit = { - characters += RegexEscaped(ch) - } - - def appendRange(start: Char, end: Char): Unit = { - characters += RegexCharacterRange(start, end) - } - - override def toRegexString: String = { - val builder = new 
StringBuilder("[") - if (negated) { - builder.append("^") - } - for (a <- characters) { - a match { - case RegexChar(ch) if requiresEscaping(ch) => - // cuDF has stricter escaping requirements for certain characters - // within a character class compared to Java or Python regex - builder.append(s"\\$ch") - case other => - builder.append(other.toRegexString) - } - } - builder.append("]") - builder.toString() - } - - private def requiresEscaping(ch: Char): Boolean = { - // there are likely other cases that we will need to add here but this - // covers everything we have seen so far during fuzzing - ch match { - case '-' => - // cuDF requires '-' to be escaped when used as a character within a character - // to disambiguate from the character range syntax 'a-b' - true - case _ => - false - } - } -} - -class RegexUnsupportedException(message: String, index: Option[Int] = None) - extends SQLException { - override def getMessage: String = { - index match { - case Some(i) => s"$message at index $index" - case _ => message - } - } -} - case class GpuRLike(left: Expression, right: Expression) extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 916b5bcc14a..37fcf9a9f1f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -19,7 +19,6 @@ import scala.collection.mutable.ListBuffer import org.scalatest.FunSuite -import org.apache.spark.sql.rapids.{QuantifierFixedLength, RegexAST, RegexChar, RegexCharacterClass, RegexCharacterRange, RegexChoice, RegexGroup, RegexParser, RegexRepetition, RegexSequence, SimpleQuantifier} class RegularExpressionParserSuite extends FunSuite { diff --git 
a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 6fa2ec37d20..213296e86db 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -23,8 +23,6 @@ import scala.util.{Random, Try} import ai.rapids.cudf.{ColumnVector, CudfException} import org.scalatest.FunSuite -import org.apache.spark.sql.rapids.{CudfRegexTranspiler, RegexAST, RegexParser, RegexUnsupportedException} - class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF does not support choice with nothing to repeat") { @@ -36,7 +34,6 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("cuDF unsupported choice cases") { val input = Seq("cat", "dog") - val supportedPatterns = Seq("cat|d*") val patterns = Seq("c*|d*", "c*|dog", "[cat]{3}|dog") patterns.foreach(pattern => { val e = intercept[CudfException] { From 95695e351edf874f838a3f00a8ecbad60994270c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 8 Nov 2021 17:16:12 -0700 Subject: [PATCH 08/14] Update docs/compatibility.md Co-authored-by: Jason Lowe --- docs/compatibility.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index fbd286d76fb..df026d05a6b 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -273,7 +273,7 @@ is provided. ### RLike -The GPU implementation of RLike has the following known issues where behavior is not consistent with Apache Spark and +The GPU implementation of `RLike` has the following known issues where behavior is not consistent with Apache Spark and this expression is disabled by default. It can be enabled setting `spark.rapids.sql.expression.RLike=true`. 
- `.` matches `\r` on the GPU but not on the CPU ([cuDF issue #9619](https://github.com/rapidsai/cudf/issues/9619)) From aeca3103b0a8836dca93753239325defeb17f3da Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 8 Nov 2021 19:33:48 -0700 Subject: [PATCH 09/14] more fully implement hex and octal parsing and address other PR feedback --- .../com/nvidia/spark/rapids/RegexParser.scala | 147 +++++++++++------- .../rapids/RegularExpressionParserSuite.scala | 22 ++- .../RegularExpressionTranspilerSuite.scala | 6 +- 3 files changed, 119 insertions(+), 56 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index bbc21a3940c..7ce2374c15e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -41,7 +41,7 @@ import scala.collection.mutable.ListBuffer class RegexParser(pattern: String) { /** index of current position within the string being parsed */ - private var i = 0 + private var pos = 0 def parse(): RegexAST = { val ast = parseInternal() @@ -76,10 +76,10 @@ class RegexParser(pattern: String) { private def isValidQuantifierAhead(): Boolean = { if (peek().contains('{')) { - val bookmark = i + val bookmark = pos consumeExpected('{') val q = parseQuantifierOrLiteralBrace() - i = bookmark + pos = bookmark q match { case _: QuantifierFixedLength | _: QuantifierVariableLength => true case _ => false @@ -92,14 +92,15 @@ class RegexParser(pattern: String) { private def parseFactor(): RegexAST = { var base = parseBase() while (!eof() && (peek().exists(ch => ch == '*' || ch == '+' || ch == '?') - || isValidQuantifierAhead())) { + || isValidQuantifierAhead())) { - if (peek().contains('{')) { + val quantifier = if (peek().contains('{')) { consumeExpected('{') - base = RegexRepetition(base, parseQuantifierOrLiteralBrace().asInstanceOf[RegexQuantifier]) + 
parseQuantifierOrLiteralBrace().asInstanceOf[RegexQuantifier] } else { - base = RegexRepetition(base, SimpleQuantifier(consume())) + SimpleQuantifier(consume()) } + base = RegexRepetition(base, quantifier) } base } @@ -114,7 +115,7 @@ class RegexParser(pattern: String) { parseEscapedCharacter() case '\u0000' => throw new RegexUnsupportedException( - "cuDF does not support null characters in regular expressions", Some(i)) + "cuDF does not support null characters in regular expressions", Some(pos)) case other => RegexChar(other) } @@ -127,7 +128,7 @@ class RegexParser(pattern: String) { } private def parseCharacterClass(): RegexCharacterClass = { - val start = i + val start = pos val characterClass = RegexCharacterClass(negated = false, characters = ListBuffer()) // loop until the end of the character class or EOF var characterClassComplete = false @@ -139,18 +140,18 @@ class RegexParser(pattern: String) { characterClass.append(ch) case ']' => characterClassComplete = true - case '^' if i == start + 1 => + case '^' if pos == start + 1 => // Negates the character class, causing it to match a single character not listed in // the character class. 
Only valid immediately after the opening '[' characterClass.negated = true - case '\n' | '\r' | '\t' | '\b' | '\f' | '\007' => + case '\n' | '\r' | '\t' | '\b' | '\f' | '\u0007' => // treat as a literal character and add to the character class characterClass.append(ch) case '\\' => peek() match { case None => throw new RegexUnsupportedException( - s"unexpected EOF while parsing escaped character", Some(i)) + s"unexpected EOF while parsing escaped character", Some(pos)) case Some(ch) => ch match { case '\\' | '^' | '-' | ']' | '+' => @@ -160,7 +161,7 @@ class RegexParser(pattern: String) { } case '\u0000' => throw new RegexUnsupportedException( - "cuDF does not support null characters in regular expressions", Some(i)) + "cuDF does not support null characters in regular expressions", Some(pos)) case _ => // check for range val start = ch @@ -184,7 +185,7 @@ class RegexParser(pattern: String) { } if (!characterClassComplete) { throw new RegexUnsupportedException( - s"unexpected EOF while parsing character class", Some(i)) + s"unexpected EOF while parsing character class", Some(pos)) } characterClass } @@ -200,11 +201,11 @@ class RegexParser(pattern: String) { private def parseQuantifierOrLiteralBrace(): RegexAST = { // assumes that '{' has already been consumed - val start = i + val start = pos def treatAsLiteralBrace() = { // this was not a quantifier, just a literal '{' - i = start + 1 + pos = start + 1 RegexChar('{') } @@ -243,84 +244,112 @@ class RegexParser(pattern: String) { private def parseEscapedCharacter(): RegexAST = { peek() match { case None => - throw new RegexUnsupportedException("escape at end of string", Some(i)) + throw new RegexUnsupportedException("escape at end of string", Some(pos)) case Some(ch) => - consumeExpected(ch) ch match { case 'A' | 'Z' => // BOL / EOL anchors + consumeExpected(ch) RegexEscaped(ch) case 's' | 'S' | 'd' | 'D' | 'w' | 'W' => // meta sequences + consumeExpected(ch) RegexEscaped(ch) case 'B' | 'b' => // word boundaries + 
consumeExpected(ch) RegexEscaped(ch) case '[' | '\\' | '^' | '$' | '.' | '⎮' | '?' | '*' | '+' | '(' | ')' | '{' | '}' => // escaped metacharacter + consumeExpected(ch) RegexEscaped(ch) case 'x' => + consumeExpected(ch) parseHexDigit - case _ if ch.isDigit => + case _ if Character.isDigit(ch) => parseOctalDigit case other => throw new RegexUnsupportedException( - s"invalid or unsupported escape character '$other'", Some(i - 1)) + s"invalid or unsupported escape character '$other'", Some(pos - 1)) } } } - private def parseHexDigit: RegexAST = { + private def isHexDigit(ch: Char): Boolean = ch.isDigit || + (ch >= 'a' && ch <= 'f') || + (ch >= 'A' && ch <= 'F') - def isHexDigit(ch: Char): Boolean = ch.isDigit || - (ch >= 'a' && ch <= 'f') || - (ch >= 'A' && ch <= 'F') + private def parseHexDigit: RegexHexDigit = { + // \xhh The character with hexadecimal value 0xhh + // \x{h...h} The character with hexadecimal value 0xh...h + // (Character.MIN_CODE_POINT <= 0xh...h <= Character.MAX_CODE_POINT) - if (i + 1 < pattern.length - && isHexDigit(pattern.charAt(i)) - && isHexDigit(pattern.charAt(i+1))) { - val hex = pattern.substring(i, i + 2) - i += 2 - RegexHexChar(hex) - } else { - throw new RegexUnsupportedException( - "Invalid hex digit", Some(i)) + val start = pos + while (!eof() && isHexDigit(pattern.charAt(pos))) { + pos += 1 } + val hexDigit = pattern.substring(start, pos) + + if (hexDigit.length < 2) { + throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit") + } + + val value = Integer.parseInt(hexDigit, 16) + if (value < Character.MIN_CODE_POINT || value > Character.MAX_CODE_POINT) { + throw new RegexUnsupportedException(s"Invalid hex digit: $hexDigit") + } + + RegexHexDigit(hexDigit) } - private def parseOctalDigit = { + private def isOctalDigit(ch: Char): Boolean = ch >= '0' && ch <= '7' - if (i + 2 < pattern.length - && pattern.charAt(i).isDigit - && pattern.charAt(i+1).isDigit - && pattern.charAt(i+2).isDigit) { - val hex = pattern.substring(i, i + 
2) - i += 3 - RegexOctalChar(hex) + private def parseOctalDigit: RegexOctalChar = { + // \0n The character with octal value 0n (0 <= n <= 7) + // \0nn The character with octal value 0nn (0 <= n <= 7) + // \0mnn The character with octal value 0mnn (0 <= m <= 3, 0 <= n <= 7) + + def parseOctalDigits(n: Integer): RegexOctalChar = { + val octal = pattern.substring(pos, pos + n) + pos += n + RegexOctalChar(octal) + } + + if (!eof() && isOctalDigit(pattern.charAt(pos))) { + if (pos + 1 < pattern.length && isOctalDigit(pattern.charAt(pos + 1))) { + if (pos + 2 < pattern.length && isOctalDigit(pattern.charAt(pos + 2)) + && pattern.charAt(pos) <= '3') { + parseOctalDigits(3) + } else { + parseOctalDigits(2) + } + } else { + parseOctalDigits(1) + } } else { throw new RegexUnsupportedException( - "Invalid octal digit", Some(i)) + "Invalid octal digit", Some(pos)) } } /** Determine if we are at the end of the input */ - private def eof(): Boolean = i == pattern.length + private def eof(): Boolean = pos == pattern.length /** Advance the index by one */ private def skip(): Unit = { if (eof()) { - throw new RegexUnsupportedException("Unexpected EOF", Some(i)) + throw new RegexUnsupportedException("Unexpected EOF", Some(pos)) } - i += 1 + pos += 1 } /** Get the next character and advance the index by one */ private def consume(): Char = { if (eof()) { - throw new RegexUnsupportedException("Unexpected EOF", Some(i)) + throw new RegexUnsupportedException("Unexpected EOF", Some(pos)) } else { - i += 1 - pattern.charAt(i - 1) + pos += 1 + pattern.charAt(pos - 1) } } @@ -329,7 +358,7 @@ class RegexParser(pattern: String) { val consumed = consume() if (consumed != expected) { throw new RegexUnsupportedException( - s"Expected '$expected' but found '$consumed'", Some(i-1)) + s"Expected '$expected' but found '$consumed'", Some(pos-1)) } consumed } @@ -339,19 +368,19 @@ class RegexParser(pattern: String) { if (eof()) { None } else { - Some(pattern.charAt(i)) + Some(pattern.charAt(pos)) } } 
private def consumeInt(): Option[Int] = { - val start = i + val start = pos while (!eof() && peek().exists(_.isDigit)) { skip() } - if (start == i) { + if (start == pos) { None } else { - Some(pattern.substring(start, i).toInt) + Some(pattern.substring(start, pos).toInt) } } @@ -377,6 +406,11 @@ class CudfRegexTranspiler { private def validate(regex: RegexAST): Unit = { regex match { + case RegexOctalChar(_) => + // cuDF produced different results compared to Spark in some cases + // example: "a\141|.$" + throw new RegexUnsupportedException( + s"cuDF does not support octal digits consistently with Spark") case RegexGroup(RegexSequence(parts)) if parts.isEmpty => // example: "()" throw new RegexUnsupportedException("cuDF does not support empty groups") @@ -521,7 +555,7 @@ sealed case class QuantifierVariableLength(minLength: Int, maxLength: Option[Int sealed trait RegexCharacterClassComponent extends RegexAST -sealed case class RegexHexChar(a: String) extends RegexCharacterClassComponent { +sealed case class RegexHexDigit(a: String) extends RegexCharacterClassComponent { override def children(): Seq[RegexAST] = Seq.empty override def toRegexString: String = s"\\x$a" } @@ -536,6 +570,11 @@ sealed case class RegexChar(a: Char) extends RegexCharacterClassComponent { override def toRegexString: String = s"$a" } +sealed case class RegexUnicodeChar(a: String) extends RegexCharacterClassComponent { + override def children(): Seq[RegexAST] = Seq.empty + override def toRegexString: String = s"\\u$a" +} + sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{ override def children(): Seq[RegexAST] = Seq.empty override def toRegexString: String = s"\\$a" diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index 37fcf9a9f1f..e56d445f123 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ 
b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -19,7 +19,6 @@ import scala.collection.mutable.ListBuffer import org.scalatest.FunSuite - class RegularExpressionParserSuite extends FunSuite { test("simple quantifier") { @@ -65,6 +64,27 @@ class RegularExpressionParserSuite extends FunSuite { RegexCharacterRange('A', 'Z')))))) } + test("hex digit") { + assert(parse(raw"\xFF") === + RegexSequence(ListBuffer(RegexHexDigit("FF")))) + } + + test("octal digit") { + val digits = Seq("1", "76", "123", "377") + for (digit <- digits) { + assert(parse(raw"\$digit") === + RegexSequence(ListBuffer(RegexOctalChar(digit)))) + } + + // parsing of the octal digit should terminate after parsing "\1" + assert(parse(raw"\18") === + RegexSequence(ListBuffer(RegexOctalChar("1"), RegexChar('8')))) + + // parsing of the octal digit should terminate after parsing "\47" + assert(parse(raw"\477") === + RegexSequence(ListBuffer(RegexOctalChar("47"), RegexChar('7')))) + } + private def parse(pattern: String): RegexAST = { new RegexParser(pattern).parse() } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 213296e86db..d013de03273 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -99,7 +99,11 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { val pattern = "1." // '.' matches '\r' on GPU but not on CPU assertCpuGpuContainsMatches(Seq(pattern), Seq("1\r2", "1\n2", "1\r\n2")) + } + ignore("known issue - octal digit") { + val pattern = "a\\141|.$" // using hex works fine e.g. 
"a\\x61|.$" + assertCpuGpuContainsMatches(Seq(pattern), Seq("] b[")) } test("character class with ranges") { @@ -224,8 +228,8 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { private def assertCpuGpuContainsMatches(javaPatterns: Seq[String], input: Seq[String]) = { for (javaPattern <- javaPatterns) { - val cudfPattern = new CudfRegexTranspiler().transpile(javaPattern) val cpu = cpuContains(javaPattern, input) + val cudfPattern = new CudfRegexTranspiler().transpile(javaPattern) val gpu = gpuContains(cudfPattern, input) for (i <- input.indices) { if (cpu(i) != gpu(i)) { From 9967f05a3e5f3564130606ecec959b978364fa41 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 8 Nov 2021 20:07:45 -0700 Subject: [PATCH 10/14] make some regex validation less specific --- .../com/nvidia/spark/rapids/RegexParser.scala | 18 +++++++----------- .../RegularExpressionTranspilerSuite.scala | 7 +++++-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 7ce2374c15e..e461f15fdb9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -411,9 +411,13 @@ class CudfRegexTranspiler { // example: "a\141|.$" throw new RegexUnsupportedException( s"cuDF does not support octal digits consistently with Spark") - case RegexGroup(RegexSequence(parts)) if parts.isEmpty => - // example: "()" - throw new RegexUnsupportedException("cuDF does not support empty groups") + case RegexEscaped(ch) if ch == 'b' || ch == 'B' => + // example: "a\Bb" + // this needs further analysis to determine why words boundaries behave + // differently between Java and cuDF + throw new RegexUnsupportedException("word boundaries are not supported") + case RegexSequence(parts) if parts.isEmpty => + throw new RegexUnsupportedException("empty sequence not 
supported") case RegexRepetition(RegexEscaped(_), _) => // example: "\B?" throw new RegexUnsupportedException(nothingToRepeat) @@ -469,14 +473,6 @@ class CudfRegexTranspiler { // falling back to CPU throw new RegexUnsupportedException(nothingToRepeat) } - parts.foreach { - case RegexEscaped(ch) if ch == 'b' || ch == 'B' => - // example: "a\Bb" - // this needs further analysis to determine why words boundaries behave - // differently between Java and cuDF - throw new RegexUnsupportedException("word boundaries are not supported") - case _ => - } case RegexCharacterClass(_, characters) => characters.foreach { case RegexChar(ch) if ch == '[' || ch == ']' => diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index d013de03273..38e938337c6 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -56,8 +56,11 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { ) } - test("cuDF does not support empty groups") { - assertUnsupported("a()?", "cuDF does not support empty groups") + test("cuDF does not support empty sequence") { + val patterns = Seq("", "a|", "()") + patterns.foreach(pattern => + assertUnsupported(pattern, "empty sequence not supported") + ) } test("cuDF does not support quantifier syntax when not quantifying anything") { From 8e63f54dd82cead3bbfd14c0844207a57c7ca5b2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 8 Nov 2021 20:29:25 -0700 Subject: [PATCH 11/14] remove redundant check --- .../src/main/scala/com/nvidia/spark/rapids/RegexParser.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 
e461f15fdb9..d5a14c8a18c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -417,6 +417,7 @@ class CudfRegexTranspiler { // differently between Java and cuDF throw new RegexUnsupportedException("word boundaries are not supported") case RegexSequence(parts) if parts.isEmpty => + // examples: "", "()", "a|", "|b" throw new RegexUnsupportedException("empty sequence not supported") case RegexRepetition(RegexEscaped(_), _) => // example: "\B?" @@ -461,10 +462,6 @@ class CudfRegexTranspiler { } case RegexSequence(parts) => - if (isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) { - // examples: "a|", "|b" - throw new RegexUnsupportedException(nothingToRepeat) - } if (isRegexChar(parts.head, '{')) { // example: "{" // cuDF would treat this as a quantifier even though in this From 8d32bdda22a724b5dc757f25a0694f935909ee0a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 8 Nov 2021 20:47:08 -0700 Subject: [PATCH 12/14] add parser test for complex expression --- .../rapids/RegularExpressionParserSuite.scala | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index e56d445f123..e06fdede065 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -85,6 +85,80 @@ class RegularExpressionParserSuite extends FunSuite { RegexSequence(ListBuffer(RegexOctalChar("47"), RegexChar('7')))) } + test("complex expression") { + val ast = parse( + "^" + // start of line + "[+\\-]?" 
+ // optional + or - at start of string + "(" + + "(" + + "(" + + "([0-9]+)|" + // digits, OR + "([0-9]*\\.[0-9]+)|" + // decimal with optional leading and mandatory trailing, OR + "([0-9]+\\.[0-9]*)" + // decimal with mandatory leading and optional trailing + ")" + + "([eE][+\\-]?[0-9]+)?" + // exponent + "[fFdD]?" + // floating-point designator + ")" + + "|Inf" + // Infinity + "|[nN][aA][nN]" + // NaN + ")" + + "$" // end of line + ) + assert(ast === + RegexSequence(ListBuffer( + RegexChar('^'), + RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexChar('+'), RegexEscaped('-'))), SimpleQuantifier('?')), + RegexGroup(RegexSequence(ListBuffer( + RegexGroup(RegexSequence(ListBuffer( + RegexGroup(RegexSequence(ListBuffer( + RegexGroup(RegexSequence(ListBuffer( + RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))), + RegexChar('|'), + RegexGroup(RegexSequence(ListBuffer( + RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('*')), + RegexEscaped('.'), + RegexRepetition(RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))), + RegexChar('|'), + RegexGroup(RegexSequence(ListBuffer(RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+')), + RegexEscaped('.'), + RegexRepetition( + RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('*')))))))), + RegexRepetition(RegexGroup(RegexSequence(ListBuffer( + RegexCharacterClass(negated = false, + ListBuffer(RegexChar('e'), + RegexChar('E'))), + RegexRepetition(RegexCharacterClass(negated = false, + ListBuffer(RegexChar('+'), RegexEscaped('-'))),SimpleQuantifier('?')), + RegexRepetition(RegexCharacterClass(negated = false, + ListBuffer(RegexCharacterRange('0', '9'))),SimpleQuantifier('+'))))), + 
SimpleQuantifier('?')), + RegexRepetition(RegexCharacterClass(negated = false, + ListBuffer(RegexChar('f'), RegexChar('F'), + RegexChar('d'), RegexChar('D'))),SimpleQuantifier('?'))))), + RegexChar('|'), RegexChar('I'), RegexChar('n'), RegexChar('f'), RegexChar('|'), + RegexCharacterClass(negated = false, ListBuffer(RegexChar('n'), RegexChar('N'))), + RegexCharacterClass(negated = false, ListBuffer(RegexChar('a'), RegexChar('A'))), + RegexCharacterClass(negated = false, ListBuffer(RegexChar('n'), RegexChar('N'))))) + ), + RegexChar('$')))) + } + + /* +Expected :RegexSequence(ListBuffer(RegexChar(^), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(+), RegexEscaped(-))),SimpleQuantifier(?)), RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(+))))), RegexChar(|), RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(*)), RegexEscaped(.), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(+))))), RegexChar(|), RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(+)), RegexEscaped(.), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(*)))))))), RegexRepetition(RegexGroup(RegexSequence(ListBuffer(RegexCharacterClass(false,ListBuffer(RegexChar(e), RegexChar(E))), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(+), RegexEscaped(-))),SimpleQuantifier(?)), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(+))))),SimpleQuantifier(?)), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(f), RegexChar(F), RegexChar(d), 
RegexChar(D))),SimpleQuantifier(?))))), RegexChar(|), RegexChar(I), RegexChar(n), RegexChar(f), RegexChar(|), RegexCharacterClass(false,ListBuffer(RegexChar(n), RegexChar(N))), RegexCharacterClass(false,ListBuffer(RegexChar(a), RegexChar(A))), RegexCharacterClass(false,ListBuffer(RegexChar(n), RegexChar(N)))))), RegexChar($))) +Actual :RegexSequence(ListBuffer(RegexChar(^), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(+), RegexEscaped(-))),SimpleQuantifier(?)), RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(+))))), RegexChar(|), RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(*)), RegexEscaped(.), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(+))))), RegexChar(|), RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(+)), RegexEscaped(.), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(*)))))))), RegexRepetition(RegexGroup(RegexSequence(ListBuffer(RegexCharacterClass(false,ListBuffer(RegexChar(e), RegexChar(E))), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(+), RegexEscaped(-))),SimpleQuantifier(?)), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(+))))),SimpleQuantifier(?)), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(f), RegexChar(F), RegexChar(d), RegexChar(D))),SimpleQuantifier(?))))), RegexChar(|), RegexChar(I), RegexChar(n), RegexChar(f), RegexChar(|), RegexCharacterClass(false,ListBuffer(RegexChar(n), RegexChar(N))), RegexCharacterClass(false,ListBuffer(RegexChar(a), RegexChar(A))), 
RegexCharacterClass(false,ListBuffer(RegexChar(n), RegexChar(N)))))), RegexChar($))) + + */ + private def parse(pattern: String): RegexAST = { new RegexParser(pattern).parse() } From 65bad0f0112c12ea38dff6361abcb2355ba442cd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 8 Nov 2021 20:47:41 -0700 Subject: [PATCH 13/14] remove comment --- .../nvidia/spark/rapids/RegularExpressionParserSuite.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index e06fdede065..deea53e7900 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -153,12 +153,6 @@ class RegularExpressionParserSuite extends FunSuite { RegexChar('$')))) } - /* -Expected :RegexSequence(ListBuffer(RegexChar(^), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(+), RegexEscaped(-))),SimpleQuantifier(?)), RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(+))))), RegexChar(|), RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(*)), RegexEscaped(.), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(+))))), RegexChar(|), RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(+)), RegexEscaped(.), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(*)))))))), RegexRepetition(RegexGroup(RegexSequence(ListBuffer(RegexCharacterClass(false,ListBuffer(RegexChar(e), 
RegexChar(E))), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(+), RegexEscaped(-))),SimpleQuantifier(?)), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange( , ))),SimpleQuantifier(+))))),SimpleQuantifier(?)), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(f), RegexChar(F), RegexChar(d), RegexChar(D))),SimpleQuantifier(?))))), RegexChar(|), RegexChar(I), RegexChar(n), RegexChar(f), RegexChar(|), RegexCharacterClass(false,ListBuffer(RegexChar(n), RegexChar(N))), RegexCharacterClass(false,ListBuffer(RegexChar(a), RegexChar(A))), RegexCharacterClass(false,ListBuffer(RegexChar(n), RegexChar(N)))))), RegexChar($))) -Actual :RegexSequence(ListBuffer(RegexChar(^), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(+), RegexEscaped(-))),SimpleQuantifier(?)), RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(+))))), RegexChar(|), RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(*)), RegexEscaped(.), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(+))))), RegexChar(|), RegexGroup(RegexSequence(ListBuffer(RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(+)), RegexEscaped(.), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(*)))))))), RegexRepetition(RegexGroup(RegexSequence(ListBuffer(RegexCharacterClass(false,ListBuffer(RegexChar(e), RegexChar(E))), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(+), RegexEscaped(-))),SimpleQuantifier(?)), RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexCharacterRange(0,9))),SimpleQuantifier(+))))),SimpleQuantifier(?)), 
RegexRepetition(RegexCharacterClass(false,ListBuffer(RegexChar(f), RegexChar(F), RegexChar(d), RegexChar(D))),SimpleQuantifier(?))))), RegexChar(|), RegexChar(I), RegexChar(n), RegexChar(f), RegexChar(|), RegexCharacterClass(false,ListBuffer(RegexChar(n), RegexChar(N))), RegexCharacterClass(false,ListBuffer(RegexChar(a), RegexChar(A))), RegexCharacterClass(false,ListBuffer(RegexChar(n), RegexChar(N)))))), RegexChar($))) - - */ - private def parse(pattern: String): RegexAST = { new RegexParser(pattern).parse() } From 62992b1968d798ca472f8766aaa86c235d40f514 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 9 Nov 2021 09:49:39 -0700 Subject: [PATCH 14/14] revert removing check that was not redundant after all --- .../main/scala/com/nvidia/spark/rapids/RegexParser.scala | 4 ++++ .../spark/rapids/RegularExpressionTranspilerSuite.scala | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index d5a14c8a18c..bfc12b865fc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -462,6 +462,10 @@ class CudfRegexTranspiler { } case RegexSequence(parts) => + if (isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) { + // examples: "a|", "|b" + throw new RegexUnsupportedException(nothingToRepeat) + } if (isRegexChar(parts.head, '{')) { // example: "{" // cuDF would treat this as a quantifier even though in this diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 38e938337c6..c7f13ee832f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ 
-233,7 +233,12 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { for (javaPattern <- javaPatterns) { val cpu = cpuContains(javaPattern, input) val cudfPattern = new CudfRegexTranspiler().transpile(javaPattern) - val gpu = gpuContains(cudfPattern, input) + val gpu = try { + gpuContains(cudfPattern, input) + } catch { + case e: CudfException => + fail(s"cuDF failed to compile pattern: $cudfPattern", e) + } for (i <- input.indices) { if (cpu(i) != gpu(i)) { fail(s"javaPattern=${toReadableString(javaPattern)}, " +