diff --git a/CHANGELOG.md b/CHANGELOG.md index 33ea4d433..b5c529997 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ * [Fix] Removed `WITH` when a snippet does not have a dependency (#657) * [Fix] Used display module when generating CTE (#649) * [Fix] Adding `--with` back because of issues with sqlglot query parser (#684) +* [Fix] Improving << parsing logic (#610) ## 0.7.9 (2023-06-19) diff --git a/src/sql/parse.py b/src/sql/parse.py index 05bf862e8..b0f15de14 100644 --- a/src/sql/parse.py +++ b/src/sql/parse.py @@ -37,7 +37,6 @@ def parse(cell, config): We're grandfathering the connection string and `<<` operator in. """ - result = { "connection": "", "sql": "", @@ -53,37 +52,21 @@ def parse(cell, config): if len(pieces) == 1: return result cell = pieces[1] - # handle no space situation around = - if pieces[0].endswith("=<<"): - result["result_var"] = pieces[0][:-3] - result["return_result_var"] = True - cell = pieces[1] - pieces = cell.split(None, 2) - # handle flexible spacing around << - if len(pieces) > 1 and pieces[1] == "<<": - if pieces[0].endswith("="): - result["result_var"] = pieces[0][:-1] - result["return_result_var"] = True - else: - result["result_var"] = pieces[0] + pointer = cell.find("<<") + if pointer != -1: + left = cell[:pointer].replace(" ", "").replace("\n", "") + right = cell[pointer + 2 :].strip(" ") - if len(pieces) == 2: - return result - cell = pieces[2] - # handle flexible spacing around =<< - elif len(pieces) > 1 and ( - (pieces[1] == "=<<") or (pieces[1] == "=" and pieces[2].startswith("<<")) - ): - result["result_var"] = pieces[0] - result["return_result_var"] = True - if pieces[1] == "=<<": - cell = pieces[2] + if "=" in left: + result["result_var"] = left[:-1] + result["return_result_var"] = True else: - pieces = cell.split(None, 3) - cell = pieces[3] + result["result_var"] = left - result["sql"] = cell + result["sql"] = right + else: + result["sql"] = cell return result diff --git a/src/tests/test_parse.py b/src/tests/test_parse.py index a72fb9920..d1d594d9a 100644 --- a/src/tests/test_parse.py +++ b/src/tests/test_parse.py @@ -88,16 +88,41 @@ def test_parse_shovel_operator(): "dest =<< SELECT * FROM work", "dest = << SELECT * FROM work", "dest=<< SELECT * FROM work", + "dest=<