From fd4b73803c8154edaa049dd0ae406da4e12224b8 Mon Sep 17 00:00:00 2001 From: "hai.yang" Date: Fri, 17 Jan 2025 10:11:05 +0800 Subject: [PATCH] =?UTF-8?q?goinception=E5=8E=BB=E9=87=8D=E8=AF=AD=E6=B3=95?= =?UTF-8?q?=E8=A7=A3=E6=9E=90=E5=90=8E=E7=9A=84=E8=A1=A8=EF=BC=8C=E5=8E=BB?= =?UTF-8?q?=E6=8E=89with=E5=90=8E=E7=9A=84=E8=99=9A=E6=8B=9F=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sql/engines/goinception.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/sql/engines/goinception.py b/sql/engines/goinception.py index eab4f0fc0e..710c81120d 100644 --- a/sql/engines/goinception.py +++ b/sql/engines/goinception.py @@ -313,6 +313,14 @@ def get_table_ref(query_tree, db_name=None): 小子树才是效率较高的算法,但是就这样吧,反正它能运行 :) """ table_ref = [] + temporary_tables = set() # 用于存储临时表名 + + # 首先识别所有的临时表名 + if "With" in query_tree: + logger.warning(query_tree["With"]) + with_definitions = query_tree["With"].get("CTEs", []) # 假设临时表定义在CTEs键下 + for definition in with_definitions: + temporary_tables.add(definition["Name"]["O"]) # 获取临时表的名称 find_queue = [query_tree] for tree in find_queue: @@ -326,20 +334,16 @@ def get_table_ref(query_tree, db_name=None): else: snodes = tree.find_max_tree("Source") if snodes: - table_ref.extend( - [ - { - "schema": snode["Source"].get("Schema", {}).get("O") - or db_name, - "name": snode["Source"].get("Name", {}).get("O", ""), - } - for snode in snodes - ] - ) + for snode in snodes: + schema_name = snode["Source"].get("Schema", {}).get("O") or db_name + table_name = snode["Source"].get("Name", {}).get("O", "") + # 检查表名是否为临时表,如果不是则添加到结果中 + if table_name not in temporary_tables: + table_ref.append({"schema": schema_name, "name": table_name}) # assert: source node must exists if table_refs node exists. # else: # raise Exception("GoInception Error: not found source node") - return table_ref + return remove_duplicates(table_ref) def close(self): if self.conn: @@ -380,3 +384,13 @@ def get_session_variables(instance): for k, v in variables.items(): set_session_sql += f"inception set session {k} = '{v}';\n" return variables, set_session_sql + +def remove_duplicates(table_list): + unique_tables = [] + seen = set() + for table in table_list: + identifier = (table['schema'], table['name']) + if identifier not in seen: + seen.add(identifier) + unique_tables.append(table) + return unique_tables \ No newline at end of file