diff --git a/sql/engines/goinception.py b/sql/engines/goinception.py index eab4f0fc0e..710c81120d 100644 --- a/sql/engines/goinception.py +++ b/sql/engines/goinception.py @@ -313,6 +313,14 @@ def get_table_ref(query_tree, db_name=None): 小子树才是效率较高的算法,但是就这样吧,反正它能运行 :) """ table_ref = [] + temporary_tables = set() # 用于存储临时表名 + + # 首先识别所有的临时表名 + if "With" in query_tree: + logger.warning(query_tree["With"]) + with_definitions = query_tree["With"].get("CTEs", []) # 假设临时表定义在CTEs键下 + for definition in with_definitions: + temporary_tables.add(definition["Name"]["O"]) # 获取临时表的名称 find_queue = [query_tree] for tree in find_queue: @@ -326,20 +334,16 @@ def get_table_ref(query_tree, db_name=None): else: snodes = tree.find_max_tree("Source") if snodes: - table_ref.extend( - [ - { - "schema": snode["Source"].get("Schema", {}).get("O") - or db_name, - "name": snode["Source"].get("Name", {}).get("O", ""), - } - for snode in snodes - ] - ) + for snode in snodes: + schema_name = snode["Source"].get("Schema", {}).get("O") or db_name + table_name = snode["Source"].get("Name", {}).get("O", "") + # 检查表名是否为临时表,如果不是则添加到结果中 + if table_name not in temporary_tables: + table_ref.append({"schema": schema_name, "name": table_name}) # assert: source node must exists if table_refs node exists. # else: # raise Exception("GoInception Error: not found source node") - return table_ref + return remove_duplicates(table_ref) def close(self): if self.conn: @@ -380,3 +384,13 @@ def get_session_variables(instance): for k, v in variables.items(): set_session_sql += f"inception set session {k} = '{v}';\n" return variables, set_session_sql + +def remove_duplicates(table_list): + unique_tables = [] + seen = set() + for table in table_list: + identifier = (table['schema'], table['name']) + if identifier not in seen: + seen.add(identifier) + unique_tables.append(table) + return unique_tables \ No newline at end of file