From 651b8bcd79ff2be9d8e99194efed9654d5f5213f Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 15:38:27 +0200 Subject: [PATCH 1/4] Use dialect's quoting directly, not via compiler --- data_diff/queries/ast_classes.py | 14 +++++++------- data_diff/queries/compiler.py | 3 --- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 70cb355f..954783de 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -74,7 +74,7 @@ class Alias(ExprNode): name: str def compile(self, c: Compiler) -> str: - return f"{c.compile(self.expr)} AS {c.quote(self.name)}" + return f"{c.compile(self.expr)} AS {c.dialect.quote(self.name)}" @property def type(self): @@ -408,14 +408,14 @@ def compile(self, c: Compiler) -> str: t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table ] if not aliases: - return c.quote(self.name) + return c.dialect.quote(self.name) elif len(aliases) > 1: raise CompileError(f"Too many aliases for column {self.name}") (alias,) = aliases - return f"{c.quote(alias.name)}.{c.quote(self.name)}" + return f"{c.dialect.quote(alias.name)}.{c.dialect.quote(self.name)}" - return c.quote(self.name) + return c.dialect.quote(self.name) @dataclass @@ -429,7 +429,7 @@ def source_table(self) -> Self: def compile(self, c: Compiler) -> str: path = self.path # c.database._normalize_table_path(self.name) - return ".".join(map(c.quote, path)) + return ".".join(map(c.dialect.quote, path)) # Statement shorthands def create(self, source_table: ITable = None, *, if_not_exists: bool = False, primary_keys: List[str] = None): @@ -515,7 +515,7 @@ class TableAlias(ExprNode, ITable): name: str def compile(self, c: Compiler) -> str: - return f"{c.compile(self.source_table)} {c.quote(self.name)}" + return f"{c.compile(self.source_table)} {c.dialect.quote(self.name)}" @dataclass @@ -989,7 +989,7 @@ def compile(self, c: Compiler) 
-> str: else: expr = c.compile(self.expr) - columns = "(%s)" % ", ".join(map(c.quote, self.columns)) if self.columns is not None else "" + columns = "(%s)" % ", ".join(map(c.dialect.quote, self.columns)) if self.columns is not None else "" return f"INSERT INTO {c.compile(self.path)}{columns} {expr}" diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index e6246236..a08d85ea 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -83,6 +83,3 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath: def add_table_context(self, *tables: Sequence, **kw) -> Self: return self.replace(_table_context=self._table_context + list(tables), **kw) - - def quote(self, s: str): - return self.dialect.quote(s) From bf4eeea5d0876457f1b4a541f597e28f7bfbeb5f Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 15:04:01 +0200 Subject: [PATCH 2/4] Move table name parsing to dialects, where they semantically belong --- data_diff/__init__.py | 8 +++++--- data_diff/__main__.py | 9 +++++---- data_diff/abcs/database_types.py | 8 ++++---- data_diff/databases/base.py | 6 +++--- data_diff/databases/bigquery.py | 8 ++++---- data_diff/databases/databricks.py | 8 ++++---- data_diff/queries/compiler.py | 3 ++- 7 files changed, 27 insertions(+), 23 deletions(-) diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 60c79b10..4ae223fb 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -1,6 +1,7 @@ from typing import Sequence, Tuple, Iterator, Optional, Union from data_diff.abcs.database_types import DbTime, DbPath +from data_diff.databases import Database from data_diff.tracking import disable_tracking from data_diff.databases._connect import connect from data_diff.diff_tables import Algorithm @@ -31,10 +32,10 @@ def connect_to_table( if isinstance(key_columns, str): key_columns = (key_columns,) - db = connect(db_info, thread_count=thread_count) + db: Database = connect(db_info, thread_count=thread_count) if 
isinstance(table_name, str): - table_name = db.parse_table_name(table_name) + table_name = db.dialect.parse_table_name(table_name) return TableSegment(db, table_name, key_columns, **kwargs) @@ -161,7 +162,8 @@ def diff_tables( ) elif algorithm == Algorithm.JOINDIFF: if isinstance(materialize_to_table, str): - materialize_to_table = table1.database.parse_table_name(eval_name_template(materialize_to_table)) + table_name = eval_name_template(materialize_to_table) + materialize_to_table = table1.database.dialect.parse_table_name(table_name) differ = JoinDiffer( threaded=threaded, max_threadpool_size=max_threadpool_size, diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 77dc7fb6..0e5255e6 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -6,12 +6,13 @@ import json import logging from itertools import islice -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import rich from rich.logging import RichHandler import click +from data_diff import Database from data_diff.schema import create_schema from data_diff.queries.api import current_timestamp @@ -425,7 +426,7 @@ def _data_diff( logging.error(f"Error while parsing age expression: {e}") return - dbs = db1, db2 + dbs: Tuple[Database, Database] = db1, db2 if interactive: for db in dbs: @@ -444,7 +445,7 @@ def _data_diff( materialize_all_rows=materialize_all_rows, table_write_limit=table_write_limit, materialize_to_table=materialize_to_table - and db1.parse_table_name(eval_name_template(materialize_to_table)), + and db1.dialect.parse_table_name(eval_name_template(materialize_to_table)), ) else: assert algorithm == Algorithm.HASHDIFF @@ -456,7 +457,7 @@ def _data_diff( ) table_names = table1, table2 - table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)] + table_paths = [db.dialect.parse_table_name(t) for db, t in safezip(dbs, table_names)] schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths))) schema1, schema2 = schemas = [ 
diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 82ec8352..c811ace5 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -176,6 +176,10 @@ class UnknownColType(ColType): class AbstractDialect(ABC): """Dialect-dependent query expressions""" + @abstractmethod + def parse_table_name(self, name: str) -> DbPath: + "Parse the given table name into a DbPath" + @property @abstractmethod def name(self) -> str: @@ -319,10 +323,6 @@ def _process_table_schema( """ - @abstractmethod - def parse_table_name(self, name: str) -> DbPath: - "Parse the given table name into a DbPath" - @abstractmethod def close(self): "Close connection(s) to the database instance. Querying will stop functioning." diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index a89ab74e..082e9815 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -156,6 +156,9 @@ class BaseDialect(AbstractDialect): PLACEHOLDER_TABLE = None # Used for Oracle + def parse_table_name(self, name: str) -> DbPath: + return parse_table_name(name) + def offset_limit( self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None ) -> str: @@ -518,9 +521,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: schema.table") - def parse_table_name(self, name: str) -> DbPath: - return parse_table_name(name) - def _query_cursor(self, c, sql_code: str) -> QueryResult: assert isinstance(sql_code, str), sql_code try: diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 5925234f..feb98bde 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -212,6 +212,10 @@ def to_comparable(self, value: str, coltype: ColType) -> str: def set_timezone_to_utc(self) -> str: raise NotImplementedError() + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return tuple(i for i in path if i is not None) + class BigQuery(Database): CONNECT_URI_HELP = "bigquery:///" @@ -288,10 +292,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: [project.]schema.table" ) - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return tuple(i for i in self._normalize_table_path(path) if i is not None) - @property def is_autocommit(self) -> bool: return True diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 1b8aa33a..67d0528d 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -94,6 +94,10 @@ def _convert_db_precision_to_digits(self, p: int) -> int: def set_timezone_to_utc(self) -> str: return "SET TIME ZONE 'UTC'" + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return tuple(i for i in path if i is not None) + class Databricks(ThreadedDatabase): dialect = Dialect() @@ -178,10 +182,6 @@ def _process_table_schema( self._refine_coltypes(path, col_dict, where) return col_dict - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return tuple(i for i in self._normalize_table_path(path) if i is not None) - @property def is_autocommit(self) -> bool: return 
True diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index a08d85ea..14eb0b77 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -79,7 +79,8 @@ def new_unique_name(self, prefix="tmp"): def new_unique_table_name(self, prefix="tmp") -> DbPath: self._counter[0] += 1 - return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") + table_name = f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}" + return self.database.dialect.parse_table_name(table_name) def add_table_context(self, *tables: Sequence, **kw) -> Self: return self.replace(_table_context=self._table_context + list(tables), **kw) From b8cf4827d02fb6fd0491c00af56343260699224b Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 16:05:28 +0200 Subject: [PATCH 3/4] Remove compiler's unused params --- data_diff/queries/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 14eb0b77..b0df04eb 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -26,7 +26,7 @@ class Root: @dataclass class Compiler(AbstractCompiler): database: AbstractDatabase - params: dict = field(default_factory=dict) + in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag From 795bb0ec2a6ad95104a6a37f003a33b222f94c42 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 25 Sep 2023 15:34:47 +0200 Subject: [PATCH 4/4] Compile all AST elements always via dialects, never directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The root authority on how to do the SQL syntax properly is the dialect — not the AST element itself. AST tree must only carry the intentions of what we want to execute, but not how it should/could be executed. 
This makes the AST truly independent of dialects and databases, allowing us to: - Focus on the main logic regardless of the SQL capabilities. - Create custom database connectors without involving AST changes every time. Before the change, adding a new database connector with unusual syntax would often require changing the AST elements to direct to `compiler.dialect.some_method()` — with only a subset of SQL being described in dialects. The other rather arbitrary part of SQL syntax was hard-coded in AST elements and could not be easily overridden without such changes. After the change, all the SQL logic is concentrated in one hierarchy of dialects, mostly in one base class. --- data_diff/abcs/compiler.py | 12 +- data_diff/abcs/database_types.py | 7 +- data_diff/bound_exprs.py | 5 - data_diff/databases/base.py | 449 ++++++++++++++++++++++++++++++- data_diff/joindiff_tables.py | 10 +- data_diff/queries/ast_classes.py | 273 +------------------ data_diff/queries/compiler.py | 55 +--- data_diff/queries/extras.py | 41 +-- 8 files changed, 475 insertions(+), 377 deletions(-) diff --git a/data_diff/abcs/compiler.py b/data_diff/abcs/compiler.py index 72fd7578..4a847d05 100644 --- a/data_diff/abcs/compiler.py +++ b/data_diff/abcs/compiler.py @@ -1,15 +1,9 @@ -from typing import Any, Dict -from abc import ABC, abstractmethod +from abc import ABC class AbstractCompiler(ABC): - @abstractmethod - def compile(self, elem: Any, params: Dict[str, Any] = None) -> str: - ... + pass class Compilable(ABC): - # TODO generic syntax, so we can write Compilable[T] for expressions returning a value of type T - @abstractmethod - def compile(self, c: AbstractCompiler) -> str: - ... 
+ pass diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index c811ace5..a679db67 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -1,11 +1,12 @@ import decimal from abc import ABC, abstractmethod -from typing import Sequence, Optional, Tuple, Type, Union, Dict, List +from typing import Sequence, Optional, Tuple, Union, Dict, List from datetime import datetime from runtype import dataclass from typing_extensions import Self +from data_diff.abcs.compiler import AbstractCompiler from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown @@ -176,6 +177,10 @@ class UnknownColType(ColType): class AbstractDialect(ABC): """Dialect-dependent query expressions""" + @abstractmethod + def compile(self, compiler: AbstractCompiler, elem, params=None) -> str: + raise NotImplementedError + @abstractmethod def parse_table_name(self, name: str) -> DbPath: "Parse the given table name into a DbPath" diff --git a/data_diff/bound_exprs.py b/data_diff/bound_exprs.py index 1742b74c..4b53846d 100644 --- a/data_diff/bound_exprs.py +++ b/data_diff/bound_exprs.py @@ -8,7 +8,6 @@ from typing_extensions import Self from data_diff.abcs.database_types import AbstractDatabase -from data_diff.abcs.compiler import AbstractCompiler from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable from data_diff.queries.api import table from data_diff.schema import create_schema @@ -37,10 +36,6 @@ def query(self, res_type=list): def type(self): return self.node.type - def compile(self, c: AbstractCompiler) -> str: - assert c.database is self.database - return self.node.compile(c) - def bind_node(node, database): return BoundNode(database, node) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 082e9815..dc43d8d7 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,3 +1,4 @@ +import functools from datetime import datetime import math import sys @@ -9,13 +10,25 @@ 
from abc import abstractmethod from uuid import UUID import decimal +import contextvars from runtype import dataclass from typing_extensions import Self -from data_diff.utils import is_uuid, safezip +from data_diff.queries.compiler import CompileError +from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString +from data_diff.utils import ArithString, is_uuid, join_iter, safezip from data_diff.queries.api import Expr, Compiler, table, Select, SKIP, Explain, Code, this -from data_diff.queries.ast_classes import Random +from data_diff.queries.ast_classes import Alias, BinOp, CaseWhen, Cast, Column, Commit, Concat, ConstantTable, Count, \ + CreateTable, Cte, \ + CurrentTimestamp, DropTable, Func, \ + GroupBy, \ + In, InsertToTable, IsDistinctFrom, \ + Join, \ + Param, \ + Random, \ + Root, TableAlias, TableOp, TablePath, TestRegex, \ + TimeTravel, TruncateTable, UnaryOp, WhenThen, _ResolveColumn from data_diff.abcs.database_types import ( AbstractDatabase, Array, @@ -39,16 +52,17 @@ Boolean, JSON, ) -from data_diff.abcs.mixins import Compilable +from data_diff.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel, Compilable from data_diff.abcs.mixins import ( AbstractMixin_Schema, AbstractMixin_RandomSample, AbstractMixin_NormalizeValue, AbstractMixin_OptimizerHints, ) -from data_diff.bound_exprs import bound_table +from data_diff.bound_exprs import BoundNode, bound_table logger = logging.getLogger("database") +cv_params = contextvars.ContextVar("params") def parse_table_name(t): @@ -98,7 +112,7 @@ def __init__(self, compiler: Compiler, gen: Generator): def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) while True: - sql = self.compiler.compile(q) + sql = self.compiler.database.dialect.compile(self.compiler, q) logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) try: try: @@ -159,6 +173,424 @@ class BaseDialect(AbstractDialect): def parse_table_name(self, name: str) 
-> DbPath: return parse_table_name(name) + def compile(self, compiler: Compiler, elem, params=None) -> str: + if params: + cv_params.set(params) + + if compiler.root and isinstance(elem, Compilable) and not isinstance(elem, Root): + from data_diff.queries.ast_classes import Select + + elem = Select(columns=[elem]) + + res = self._compile(compiler, elem) + if compiler.root and compiler._subqueries: + subq = ", ".join(f"\n {k} AS ({v})" for k, v in compiler._subqueries.items()) + compiler._subqueries.clear() + return f"WITH {subq}\n{res}" + return res + + def _compile(self, compiler: Compiler, elem) -> str: + if elem is None: + return "NULL" + elif isinstance(elem, Compilable): + return self.render_compilable(compiler.replace(root=False), elem) + elif isinstance(elem, str): + return f"'{elem}'" + elif isinstance(elem, (int, float)): + return str(elem) + elif isinstance(elem, datetime): + return self.timestamp_value(elem) + elif isinstance(elem, bytes): + return f"b'{elem.decode()}'" + elif isinstance(elem, ArithString): + return f"'{elem}'" + assert False, elem + + def render_compilable(self, c: Compiler, elem: Compilable) -> str: + # All ifs are only for better code navigation, IDE usage detection, and type checking. + # The last catch-all would render them anyway — it is a typical "visitor" pattern. 
+ if isinstance(elem, Column): + return self.render_column(c, elem) + elif isinstance(elem, Cte): + return self.render_cte(c, elem) + elif isinstance(elem, Commit): + return self.render_commit(c, elem) + elif isinstance(elem, Param): + return self.render_param(c, elem) + elif isinstance(elem, NormalizeAsString): + return self.render_normalizeasstring(c, elem) + elif isinstance(elem, ApplyFuncAndNormalizeAsString): + return self.render_applyfuncandnormalizeasstring(c, elem) + elif isinstance(elem, Checksum): + return self.render_checksum(c, elem) + elif isinstance(elem, Concat): + return self.render_concat(c, elem) + elif isinstance(elem, TestRegex): + return self.render_testregex(c, elem) + elif isinstance(elem, Func): + return self.render_func(c, elem) + elif isinstance(elem, WhenThen): + return self.render_whenthen(c, elem) + elif isinstance(elem, CaseWhen): + return self.render_casewhen(c, elem) + elif isinstance(elem, IsDistinctFrom): + return self.render_isdistinctfrom(c, elem) + elif isinstance(elem, UnaryOp): + return self.render_unaryop(c, elem) + elif isinstance(elem, BinOp): + return self.render_binop(c, elem) + elif isinstance(elem, TablePath): + return self.render_tablepath(c, elem) + elif isinstance(elem, TableAlias): + return self.render_tablealias(c, elem) + elif isinstance(elem, TableOp): + return self.render_tableop(c, elem) + elif isinstance(elem, Select): + return self.render_select(c, elem) + elif isinstance(elem, Join): + return self.render_join(c, elem) + elif isinstance(elem, GroupBy): + return self.render_groupby(c, elem) + elif isinstance(elem, Count): + return self.render_count(c, elem) + elif isinstance(elem, Alias): + return self.render_alias(c, elem) + elif isinstance(elem, In): + return self.render_in(c, elem) + elif isinstance(elem, Cast): + return self.render_cast(c, elem) + elif isinstance(elem, Random): + return self.render_random(c, elem) + elif isinstance(elem, Explain): + return self.render_explain(c, elem) + elif 
isinstance(elem, CurrentTimestamp): + return self.render_currenttimestamp(c, elem) + elif isinstance(elem, TimeTravel): + return self.render_timetravel(c, elem) + elif isinstance(elem, CreateTable): + return self.render_createtable(c, elem) + elif isinstance(elem, DropTable): + return self.render_droptable(c, elem) + elif isinstance(elem, TruncateTable): + return self.render_truncatetable(c, elem) + elif isinstance(elem, InsertToTable): + return self.render_inserttotable(c, elem) + elif isinstance(elem, Code): + return self.render_code(c, elem) + elif isinstance(elem, BoundNode): + return self.render_boundnode(c, elem) + elif isinstance(elem, _ResolveColumn): + return self.render__resolvecolumn(c, elem) + + method_name = f"render_{elem.__class__.__name__.lower()}" + method = getattr(self, method_name, None) + if method is not None: + return method(c, elem) + else: + raise RuntimeError(f"Cannot render AST of type {elem.__class__}") + # return elem.compile(compiler.replace(root=False)) + + def render_column(self, c: Compiler, elem: Column) -> str: + if c._table_context: + if len(c._table_context) > 1: + aliases = [ + t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is elem.source_table + ] + if not aliases: + return self.quote(elem.name) + elif len(aliases) > 1: + raise CompileError(f"Too many aliases for column {elem.name}") + (alias,) = aliases + + return f"{self.quote(alias.name)}.{self.quote(elem.name)}" + + return self.quote(elem.name) + + def render_cte(self, parent_c: Compiler, elem: Cte) -> str: + c: Compiler = parent_c.replace(_table_context=[], in_select=False) + compiled = self.compile(c, elem.source_table) + + name = elem.name or parent_c.new_unique_name() + name_params = f"{name}({', '.join(elem.params)})" if elem.params else name + parent_c._subqueries[name_params] = compiled + + return name + + def render_commit(self, c: Compiler, elem: Commit) -> str: + return "COMMIT" if not c.database.is_autocommit else SKIP + + def 
render_param(self, c: Compiler, elem: Param) -> str: + params = cv_params.get() + return self._compile(c, params[elem.name]) + + def render_normalizeasstring(self, c: Compiler, elem: NormalizeAsString) -> str: + expr = self.compile(c, elem.expr) + return self.normalize_value_by_type(expr, elem.expr_type or elem.expr.type) + + def render_applyfuncandnormalizeasstring(self, c: Compiler, elem: ApplyFuncAndNormalizeAsString) -> str: + expr = elem.expr + expr_type = expr.type + + if isinstance(expr_type, Native_UUID): + # Normalize first, apply template after (for uuids) + # Needed because min/max(uuid) fails in postgresql + expr = NormalizeAsString(expr, expr_type) + if elem.apply_func is not None: + expr = elem.apply_func(expr) # Apply template using Python's string formatting + + else: + # Apply template before normalizing (for ints) + if elem.apply_func is not None: + expr = elem.apply_func(expr) # Apply template using Python's string formatting + expr = NormalizeAsString(expr, expr_type) + + return self.compile(c, expr) + + def render_checksum(self, c: Compiler, elem: Checksum) -> str: + if len(elem.exprs) > 1: + exprs = [Code(f"coalesce({self.compile(c, expr)}, '')") for expr in elem.exprs] + # exprs = [self.compile(c, e) for e in exprs] + expr = Concat(exprs, "|") + else: + # No need to coalesce - safe to assume that key cannot be null + (expr,) = elem.exprs + expr = self.compile(c, expr) + md5 = self.md5_as_int(expr) + return f"sum({md5})" + + def render_concat(self, c: Compiler, elem: Concat) -> str: + # We coalesce because on some DBs (e.g. 
MySQL) concat('a', NULL) is NULL + items = [f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '')" for expr in elem.exprs] + assert items + if len(items) == 1: + return items[0] + + if elem.sep: + items = list(join_iter(f"'{elem.sep}'", items)) + return self.concat(items) + + def render_alias(self, c: Compiler, elem: Alias) -> str: + return f"{self.compile(c, elem.expr)} AS {self.quote(elem.name)}" + + def render_testregex(self, c: Compiler, elem: TestRegex) -> str: + # TODO: move this method to that mixin! raise here instead, unconditionally. + if not isinstance(self, AbstractMixin_Regex): + raise NotImplementedError(f"No regex implementation for database '{c.dialect}'") + regex = self.test_regex(elem.string, elem.pattern) + return self.compile(c, regex) + + def render_count(self, c: Compiler, elem: Count) -> str: + expr = self.compile(c, elem.expr) if elem.expr else "*" + if elem.distinct: + return f"count(distinct {expr})" + return f"count({expr})" + + def render_code(self, c: Compiler, elem: Code) -> str: + if not elem.args: + return elem.code + + args = {k: self.compile(c, v) for k, v in elem.args.items()} + return elem.code.format(**args) + + def render_func(self, c: Compiler, elem: Func) -> str: + args = ", ".join(self.compile(c, e) for e in elem.args) + return f"{elem.name}({args})" + + def render_whenthen(self, c: Compiler, elem: WhenThen) -> str: + return f"WHEN {self.compile(c, elem.when)} THEN {self.compile(c, elem.then)}" + + def render_casewhen(self, c: Compiler, elem: CaseWhen) -> str: + assert elem.cases + when_thens = " ".join(self.compile(c, case) for case in elem.cases) + else_expr = (" ELSE " + self.compile(c, elem.else_expr)) if elem.else_expr is not None else "" + return f"CASE {when_thens}{else_expr} END" + + def render_isdistinctfrom(self, c: Compiler, elem: IsDistinctFrom) -> str: + a = self.to_comparable(self.compile(c, elem.a), elem.a.type) + b = self.to_comparable(self.compile(c, elem.b), elem.b.type) + return 
self.is_distinct_from(a, b) + + def render_unaryop(self, c: Compiler, elem: UnaryOp) -> str: + return f"({elem.op}{self.compile(c, elem.expr)})" + + def render_binop(self, c: Compiler, elem: BinOp) -> str: + expr = f" {elem.op} ".join(self.compile(c, a) for a in elem.args) + return f"({expr})" + + def render_tablepath(self, c: Compiler, elem: TablePath) -> str: + path = elem.path # c.database._normalize_table_path(self.name) + return ".".join(map(self.quote, path)) + + def render_tablealias(self, c: Compiler, elem: TableAlias) -> str: + return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}" + + def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str: + c: Compiler = parent_c.replace(in_select=False) + table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}" + if parent_c.in_select: + table_expr = f"({table_expr}) {c.new_unique_name()}" + elif parent_c.in_join: + table_expr = f"({table_expr})" + return table_expr + + def render_boundnode(self, c: Compiler, elem: BoundNode) -> str: + assert self is elem.database.dialect + return self.compile(c, elem.node) + + def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str: + return self.compile(c, elem._get_resolved()) + + def render_select(self, parent_c: Compiler, elem: Select) -> str: + c: Compiler = parent_c.replace(in_select=True) # .add_table_context(self.table) + compile_fn = functools.partial(self.compile, c) + + columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*" + distinct = "DISTINCT " if elem.distinct else "" + optimizer_hints = self.optimizer_hints(elem.optimizer_hints) if elem.optimizer_hints else "" + select = f"SELECT {optimizer_hints}{distinct}{columns}" + + if elem.table: + select += " FROM " + self.compile(c, elem.table) + elif self.PLACEHOLDER_TABLE: + select += f" FROM {self.PLACEHOLDER_TABLE}" + + if elem.where_exprs: + select += " WHERE " + " AND ".join(map(compile_fn, elem.where_exprs)) + + if 
elem.group_by_exprs: + select += " GROUP BY " + ", ".join(map(compile_fn, elem.group_by_exprs)) + + if elem.having_exprs: + assert elem.group_by_exprs + select += " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) + + if elem.order_by_exprs: + select += " ORDER BY " + ", ".join(map(compile_fn, elem.order_by_exprs)) + + if elem.limit_expr is not None: + has_order_by = bool(elem.order_by_exprs) + select += " " + self.offset_limit(0, elem.limit_expr, has_order_by=has_order_by) + + if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: + select = f"({select})" + return select + + def render_join(self, parent_c: Compiler, elem: Join) -> str: + tables = [ + t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in elem.source_tables + ] + c = parent_c.add_table_context(*tables, in_join=True, in_select=False) + op = " JOIN " if elem.op is None else f" {elem.op} JOIN " + joined = op.join(self.compile(c, t) for t in tables) + + if elem.on_exprs: + on = " AND ".join(self.compile(c, e) for e in elem.on_exprs) + res = f"{joined} ON {on}" + else: + res = joined + + compile_fn = functools.partial(self.compile, c) + columns = "*" if elem.columns is None else ", ".join(map(compile_fn, elem.columns)) + select = f"SELECT {columns} FROM {res}" + + if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: + select = f"({select})" + return select + + def render_groupby(self, c: Compiler, elem: GroupBy) -> str: + compile_fn = functools.partial(self.compile, c) + + if elem.values is None: + raise CompileError(".group_by() must be followed by a call to .agg()") + + keys = [str(i + 1) for i in range(len(elem.keys))] + columns = (elem.keys or []) + (elem.values or []) + if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None: + return self.compile( + c, + elem.table.replace( + columns=columns, + group_by_exprs=[Code(k) for k in 
keys], + having_exprs=elem.having_exprs, + ) + ) + + keys_str = ", ".join(keys) + columns_str = ", ".join(self.compile(c, x) for x in columns) + having_str = ( + " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else "" + ) + select = ( + f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" + ) + + if c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif c.in_join: + select = f"({select})" + return select + + def render_in(self, c: Compiler, elem: In) -> str: + compile_fn = functools.partial(self.compile, c) + elems = ", ".join(map(compile_fn, elem.list)) + return f"({self.compile(c, elem.expr)} IN ({elems}))" + + def render_cast(self, c: Compiler, elem: Cast) -> str: + return f"cast({self.compile(c, elem.expr)} as {self.compile(c, elem.target_type)})" + + def render_random(self, c: Compiler, elem: Random) -> str: + return self.random() + + def render_explain(self, c: Compiler, elem: Explain) -> str: + return self.explain_as_text(self.compile(c, elem.select)) + + def render_currenttimestamp(self, c: Compiler, elem: CurrentTimestamp) -> str: + return self.current_timestamp() + + def render_timetravel(self, c: Compiler, elem: TimeTravel) -> str: + assert isinstance(c, AbstractMixin_TimeTravel) + return self.compile( + c, + # TODO: why is it c.? why not self? time-travelling is the dialect's thing, isn't it? 
+ c.time_travel( + elem.table, before=elem.before, timestamp=elem.timestamp, offset=elem.offset, statement=elem.statement + ) + ) + + def render_createtable(self, c: Compiler, elem: CreateTable) -> str: + ne = "IF NOT EXISTS " if elem.if_not_exists else "" + if elem.source_table: + return f"CREATE TABLE {ne}{self.compile(c, elem.path)} AS {self.compile(c, elem.source_table)}" + + schema = ", ".join(f"{self.quote(k)} {self.type_repr(v)}" for k, v in elem.path.schema.items()) + pks = ( + ", PRIMARY KEY (%s)" % ", ".join(elem.primary_keys) + if elem.primary_keys and self.SUPPORTS_PRIMARY_KEY + else "" + ) + return f"CREATE TABLE {ne}{self.compile(c, elem.path)}({schema}{pks})" + + def render_droptable(self, c: Compiler, elem: DropTable) -> str: + ie = "IF EXISTS " if elem.if_exists else "" + return f"DROP TABLE {ie}{self.compile(c, elem.path)}" + + def render_truncatetable(self, c: Compiler, elem: TruncateTable) -> str: + return f"TRUNCATE TABLE {self.compile(c, elem.path)}" + + def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str: + if isinstance(elem.expr, ConstantTable): + expr = self.constant_values(elem.expr.rows) + else: + expr = self.compile(c, elem.expr) + + columns = "(%s)" % ", ".join(map(self.quote, elem.columns)) if elem.columns is not None else "" + + return f"INSERT INTO {self.compile(c, elem.path)}{columns} {expr}" + def offset_limit( self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None ) -> str: @@ -335,8 +767,7 @@ def name(self): return type(self).__name__ def compile(self, sql_ast): - compiler = Compiler(self) - return compiler.compile(sql_ast) + return self.dialect.compile(Compiler(self), sql_ast) def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): """Query the given SQL code/AST, and attempt to convert the result to type 'res_type' @@ -359,14 +790,14 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): else: if res_type is None: res_type = 
sql_ast.type - sql_code = compiler.compile(sql_ast) + sql_code = self.compile(sql_ast) if sql_code is SKIP: return SKIP logger.debug("Running SQL (%s): %s", self.name, sql_code) if self._interactive and isinstance(sql_ast, Select): - explained_sql = compiler.compile(Explain(sql_ast)) + explained_sql = self.compile(Explain(sql_ast)) explain = self._query(explained_sql) for row in explain: # Most returned a 1-tuple. Presto returns a string diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 667786a7..c40a2b99 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -58,15 +58,15 @@ def sample(table_expr): def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str: db = c.database - c = c.replace(root=False) # we're compiling fragments, not full queries + c: Compiler = c.replace(root=False) # we're compiling fragments, not full queries if isinstance(db, BigQuery): - return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" + return f"create table {c.dialect.compile(c, path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.dialect.compile(c, expr)}" elif isinstance(db, Presto): - return f"create table {c.compile(path)} as {c.compile(expr)}" + return f"create table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}" elif isinstance(db, Oracle): - return f"create global temporary table {c.compile(path)} as {c.compile(expr)}" + return f"create global temporary table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}" else: - return f"create temporary table {c.compile(path)} as {c.compile(expr)}" + return f"create temporary table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}" def bool_to_int(x): diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 954783de..0013fef7 100644 --- a/data_diff/queries/ast_classes.py +++ 
b/data_diff/queries/ast_classes.py @@ -5,13 +5,12 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.utils import join_iter, ArithString +from data_diff.utils import ArithString from data_diff.abcs.compiler import Compilable from data_diff.abcs.database_types import AbstractTable -from data_diff.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel from data_diff.schema import Schema -from data_diff.queries.compiler import Compiler, cv_params, Root, CompileError +from data_diff.queries.compiler import Compiler from data_diff.queries.base import SKIP, args_as_tuple, SqeletonError from data_diff.abcs.database_types import DbPath @@ -24,6 +23,10 @@ class QB_TypeError(QueryBuilderError): pass +class Root: + "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)" + + class ExprNode(Compilable): "Base class for query expression nodes" @@ -54,13 +57,6 @@ class Code(ExprNode, Root): code: str args: Dict[str, Expr] = None - def compile(self, c: Compiler) -> str: - if not self.args: - return self.code - - args = {k: c.compile(v) for k, v in self.args.items()} - return self.code.format(**args) - def _expr_type(e: Expr) -> type: if isinstance(e, ExprNode): @@ -73,9 +69,6 @@ class Alias(ExprNode): expr: Expr name: str - def compile(self, c: Compiler) -> str: - return f"{c.compile(self.expr)} AS {c.dialect.quote(self.name)}" - @property def type(self): return _expr_type(self.expr) @@ -178,32 +171,13 @@ class Concat(ExprNode): exprs: list sep: str = None - def compile(self, c: Compiler) -> str: - # We coalesce because on some DBs (e.g. 
MySQL) concat('a', NULL) is NULL - items = [f"coalesce({c.compile(Code(c.dialect.to_string(c.compile(expr))))}, '')" for expr in self.exprs] - assert items - if len(items) == 1: - return items[0] - - if self.sep: - items = list(join_iter(f"'{self.sep}'", items)) - return c.dialect.concat(items) - @dataclass class Count(ExprNode): expr: Expr = None distinct: bool = False - type = int - def compile(self, c: Compiler) -> str: - expr = c.compile(self.expr) if self.expr else "*" - if self.distinct: - return f"count(distinct {expr})" - - return f"count({expr})" - class LazyOps: def __add__(self, other): @@ -262,43 +236,24 @@ class TestRegex(ExprNode, LazyOps): string: Expr pattern: Expr - def compile(self, c: Compiler) -> str: - if not isinstance(c.dialect, AbstractMixin_Regex): - raise NotImplementedError(f"No regex implementation for database '{c.database}'") - regex = c.dialect.test_regex(self.string, self.pattern) - return c.compile(regex) - @dataclass(eq=False) class Func(ExprNode, LazyOps): name: str args: Sequence[Expr] - def compile(self, c: Compiler) -> str: - args = ", ".join(c.compile(e) for e in self.args) - return f"{self.name}({args})" - @dataclass class WhenThen(ExprNode): when: Expr then: Expr - def compile(self, c: Compiler) -> str: - return f"WHEN {c.compile(self.when)} THEN {c.compile(self.then)}" - @dataclass class CaseWhen(ExprNode): cases: Sequence[WhenThen] else_expr: Expr = None - def compile(self, c: Compiler) -> str: - assert self.cases - when_thens = " ".join(c.compile(case) for case in self.cases) - else_expr = (" ELSE " + c.compile(self.else_expr)) if self.else_expr is not None else "" - return f"CASE {when_thens}{else_expr} END" - @property def type(self): then_types = {_expr_type(case.then) for case in self.cases} @@ -353,21 +308,12 @@ class IsDistinctFrom(ExprNode, LazyOps): b: Expr type = bool - def compile(self, c: Compiler) -> str: - a = c.dialect.to_comparable(c.compile(self.a), self.a.type) - b = 
c.dialect.to_comparable(c.compile(self.b), self.b.type) - return c.dialect.is_distinct_from(a, b) - @dataclass(eq=False, order=False) class BinOp(ExprNode, LazyOps): op: str args: Sequence[Expr] - def compile(self, c: Compiler) -> str: - expr = f" {self.op} ".join(c.compile(a) for a in self.args) - return f"({expr})" - @property def type(self): types = {_expr_type(i) for i in self.args} @@ -382,9 +328,6 @@ class UnaryOp(ExprNode, LazyOps): op: str expr: Expr - def compile(self, c: Compiler) -> str: - return f"({self.op}{c.compile(self.expr)})" - class BinBoolOp(BinOp): type = bool @@ -401,22 +344,6 @@ def type(self): raise QueryBuilderError(f"Schema required for table {self.source_table}") return self.source_table.schema[self.name] - def compile(self, c: Compiler) -> str: - if c._table_context: - if len(c._table_context) > 1: - aliases = [ - t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table - ] - if not aliases: - return c.dialect.quote(self.name) - elif len(aliases) > 1: - raise CompileError(f"Too many aliases for column {self.name}") - (alias,) = aliases - - return f"{c.dialect.quote(alias.name)}.{c.dialect.quote(self.name)}" - - return c.dialect.quote(self.name) - @dataclass class TablePath(ExprNode, ITable): @@ -427,10 +354,6 @@ class TablePath(ExprNode, ITable): def source_table(self) -> Self: return self - def compile(self, c: Compiler) -> str: - path = self.path # c.database._normalize_table_path(self.name) - return ".".join(map(c.dialect.quote, path)) - # Statement shorthands def create(self, source_table: ITable = None, *, if_not_exists: bool = False, primary_keys: List[str] = None): """Returns a query expression to create a new table. 
@@ -514,9 +437,6 @@ class TableAlias(ExprNode, ITable): source_table: ITable name: str - def compile(self, c: Compiler) -> str: - return f"{c.compile(self.source_table)} {c.dialect.quote(self.name)}" - @dataclass class Join(ExprNode, ITable, Root): @@ -564,29 +484,6 @@ def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: # TODO Ensure exprs <= self.columns ? return self.replace(columns=exprs) - def compile(self, parent_c: Compiler) -> str: - tables = [ - t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables - ] - c = parent_c.add_table_context(*tables, in_join=True, in_select=False) - op = " JOIN " if self.op is None else f" {self.op} JOIN " - joined = op.join(c.compile(t) for t in tables) - - if self.on_exprs: - on = " AND ".join(c.compile(e) for e in self.on_exprs) - res = f"{joined} ON {on}" - else: - res = joined - - columns = "*" if self.columns is None else ", ".join(map(c.compile, self.columns)) - select = f"SELECT {columns} FROM {res}" - - if parent_c.in_select: - select = f"({select}) {c.new_unique_name()}" - elif parent_c.in_join: - select = f"({select})" - return select - @dataclass class GroupBy(ExprNode, ITable, Root): @@ -619,36 +516,6 @@ def agg(self, *exprs) -> Self: resolve_names(self.table, exprs) return self.replace(values=(self.values or []) + exprs) - def compile(self, c: Compiler) -> str: - if self.values is None: - raise CompileError(".group_by() must be followed by a call to .agg()") - - keys = [str(i + 1) for i in range(len(self.keys))] - columns = (self.keys or []) + (self.values or []) - if isinstance(self.table, Select) and self.table.columns is None and self.table.group_by_exprs is None: - return c.compile( - self.table.replace( - columns=columns, - group_by_exprs=[Code(k) for k in keys], - having_exprs=self.having_exprs, - ) - ) - - keys_str = ", ".join(keys) - columns_str = ", ".join(c.compile(x) for x in columns) - having_str = ( - " HAVING " + " AND ".join(map(c.compile, 
self.having_exprs)) if self.having_exprs is not None else "" - ) - select = ( - f"SELECT {columns_str} FROM {c.replace(in_select=True).compile(self.table)} GROUP BY {keys_str}{having_str}" - ) - - if c.in_select: - select = f"({select}) {c.new_unique_name()}" - elif c.in_join: - select = f"({select})" - return select - @dataclass class TableOp(ExprNode, ITable, Root): @@ -672,15 +539,6 @@ def schema(self): assert len(s1) == len(s2) return s1 - def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(in_select=False) - table_expr = f"{c.compile(self.table1)} {self.op} {c.compile(self.table2)}" - if parent_c.in_select: - table_expr = f"({table_expr}) {c.new_unique_name()}" - elif parent_c.in_join: - table_expr = f"({table_expr})" - return table_expr - @dataclass class Select(ExprNode, ITable, Root): @@ -705,42 +563,6 @@ def schema(self): def source_table(self): return self - def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(in_select=True) # .add_table_context(self.table) - - columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" - distinct = "DISTINCT " if self.distinct else "" - optimizer_hints = c.dialect.optimizer_hints(self.optimizer_hints) if self.optimizer_hints else "" - select = f"SELECT {optimizer_hints}{distinct}{columns}" - - if self.table: - select += " FROM " + c.compile(self.table) - elif c.dialect.PLACEHOLDER_TABLE: - select += f" FROM {c.dialect.PLACEHOLDER_TABLE}" - - if self.where_exprs: - select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) - - if self.group_by_exprs: - select += " GROUP BY " + ", ".join(map(c.compile, self.group_by_exprs)) - - if self.having_exprs: - assert self.group_by_exprs - select += " HAVING " + " AND ".join(map(c.compile, self.having_exprs)) - - if self.order_by_exprs: - select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs)) - - if self.limit_expr is not None: - has_order_by = bool(self.order_by_exprs) - select += " " + 
c.dialect.offset_limit(0, self.limit_expr, has_order_by=has_order_by) - - if parent_c.in_select: - select = f"({select}) {c.new_unique_name()}" - elif parent_c.in_join: - select = f"({select})" - return select - @classmethod def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs): assert "table" not in kwargs @@ -783,16 +605,6 @@ class Cte(ExprNode, ITable): name: str = None params: Sequence[str] = None - def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(_table_context=[], in_select=False) - compiled = c.compile(self.source_table) - - name = self.name or parent_c.new_unique_name() - name_params = f"{name}({', '.join(self.params)})" if self.params else name - parent_c._subqueries[name_params] = compiled - - return name - @property def schema(self): # TODO add cte to schema @@ -829,9 +641,6 @@ def _get_resolved(self) -> Expr: raise QueryBuilderError(f"Column not resolved: {self.resolve_name}") return self.resolved - def compile(self, c: Compiler) -> str: - return self._get_resolved().compile(c) - @property def type(self): return self._get_resolved().type @@ -860,58 +669,34 @@ def __getitem__(self, name): class In(ExprNode): expr: Expr list: Sequence[Expr] - type = bool - def compile(self, c: Compiler): - elems = ", ".join(map(c.compile, self.list)) - return f"({c.compile(self.expr)} IN ({elems}))" - @dataclass class Cast(ExprNode): expr: Expr target_type: Expr - def compile(self, c: Compiler) -> str: - return f"cast({c.compile(self.expr)} as {c.compile(self.target_type)})" - @dataclass class Random(ExprNode, LazyOps): type = float - def compile(self, c: Compiler) -> str: - return c.dialect.random() - @dataclass class ConstantTable(ExprNode): rows: Sequence[Sequence] - def compile(self, c: Compiler) -> str: - raise NotImplementedError() - - def compile_for_insert(self, c: Compiler): - return c.dialect.constant_values(self.rows) - @dataclass class Explain(ExprNode, Root): select: Select - type = str - def 
compile(self, c: Compiler) -> str: - return c.dialect.explain_as_text(c.compile(self.select)) - class CurrentTimestamp(ExprNode): type = datetime - def compile(self, c: Compiler) -> str: - return c.dialect.current_timestamp() - @dataclass class TimeTravel(ITable): @@ -921,14 +706,6 @@ class TimeTravel(ITable): offset: int = None statement: str = None - def compile(self, c: Compiler) -> str: - assert isinstance(c, AbstractMixin_TimeTravel) - return c.compile( - c.time_travel( - self.table, before=self.before, timestamp=self.timestamp, offset=self.offset, statement=self.statement - ) - ) - # DDL @@ -944,37 +721,17 @@ class CreateTable(Statement): if_not_exists: bool = False primary_keys: List[str] = None - def compile(self, c: Compiler) -> str: - ne = "IF NOT EXISTS " if self.if_not_exists else "" - if self.source_table: - return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" - - schema = ", ".join(f"{c.dialect.quote(k)} {c.dialect.type_repr(v)}" for k, v in self.path.schema.items()) - pks = ( - ", PRIMARY KEY (%s)" % ", ".join(self.primary_keys) - if self.primary_keys and c.dialect.SUPPORTS_PRIMARY_KEY - else "" - ) - return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})" - @dataclass class DropTable(Statement): path: TablePath if_exists: bool = False - def compile(self, c: Compiler) -> str: - ie = "IF EXISTS " if self.if_exists else "" - return f"DROP TABLE {ie}{c.compile(self.path)}" - @dataclass class TruncateTable(Statement): path: TablePath - def compile(self, c: Compiler) -> str: - return f"TRUNCATE TABLE {c.compile(self.path)}" - @dataclass class InsertToTable(Statement): @@ -983,16 +740,6 @@ class InsertToTable(Statement): columns: List[str] = None returning_exprs: List[str] = None - def compile(self, c: Compiler) -> str: - if isinstance(self.expr, ConstantTable): - expr = self.expr.compile_for_insert(c) - else: - expr = c.compile(self.expr) - - columns = "(%s)" % ", ".join(map(c.dialect.quote, self.columns)) if 
self.columns is not None else "" - - return f"INSERT INTO {c.compile(self.path)}{columns} {expr}" - def returning(self, *exprs) -> Self: """Add a 'RETURNING' clause to the insert expression. @@ -1014,20 +761,12 @@ def returning(self, *exprs) -> Self: class Commit(Statement): """Generate a COMMIT statement, if we're in the middle of a transaction, or in auto-commit. Otherwise SKIP.""" - def compile(self, c: Compiler) -> str: - return "COMMIT" if not c.database.is_autocommit else SKIP - @dataclass class Param(ExprNode, ITable): """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" - name: str @property def source_table(self): return self - - def compile(self, c: Compiler) -> str: - params = cv_params.get() - return c._compile(params[self.name]) diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index b0df04eb..224ad636 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,30 +1,30 @@ import random from dataclasses import field -from datetime import datetime from typing import Any, Dict, Sequence, List from runtype import dataclass from typing_extensions import Self -from data_diff.utils import ArithString from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect, DbPath -from data_diff.abcs.compiler import AbstractCompiler, Compilable - -import contextvars - -cv_params = contextvars.ContextVar("params") +from data_diff.abcs.compiler import AbstractCompiler class CompileError(Exception): pass -class Root: - "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)" - - @dataclass class Compiler(AbstractCompiler): + """ + Compiler bears the context for a single compilation. + + There can be multiple compilation per app run. + There can be multiple compilers in one compilation (with varying contexts). + """ + + # Database is needed to normalize tables. Dialect is needed for recursive compilations. 
+ # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects. + # In practice, we currently bind the dialects to the specific database classes. database: AbstractDatabase in_select: bool = False # Compilation runtime flag @@ -40,38 +40,9 @@ class Compiler(AbstractCompiler): def dialect(self) -> AbstractDialect: return self.database.dialect + # TODO: DEPRECATED: Remove once the dialect is used directly in all places. def compile(self, elem, params=None) -> str: - if params: - cv_params.set(params) - - if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root): - from data_diff.queries.ast_classes import Select - - elem = Select(columns=[elem]) - - res = self._compile(elem) - if self.root and self._subqueries: - subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items()) - self._subqueries.clear() - return f"WITH {subq}\n{res}" - return res - - def _compile(self, elem) -> str: - if elem is None: - return "NULL" - elif isinstance(elem, Compilable): - return elem.compile(self.replace(root=False)) - elif isinstance(elem, str): - return f"'{elem}'" - elif isinstance(elem, (int, float)): - return str(elem) - elif isinstance(elem, datetime): - return self.dialect.timestamp_value(elem) - elif isinstance(elem, bytes): - return f"b'{elem.decode()}'" - elif isinstance(elem, ArithString): - return f"'{elem}'" - assert False, elem + return self.dialect.compile(self, elem, params) def new_unique_name(self, prefix="tmp"): self._counter[0] += 1 diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index 8e916601..bb0c8299 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -1,12 +1,10 @@ "Useful AST classes that don't quite fall within the scope of regular SQL" - from typing import Callable, Sequence from runtype import dataclass -from data_diff.abcs.database_types import ColType, Native_UUID +from data_diff.abcs.database_types import ColType -from data_diff.queries.compiler 
import Compiler -from data_diff.queries.ast_classes import Expr, ExprNode, Concat, Code +from data_diff.queries.ast_classes import Expr, ExprNode @dataclass @@ -15,48 +13,13 @@ class NormalizeAsString(ExprNode): expr_type: ColType = None type = str - def compile(self, c: Compiler) -> str: - expr = c.compile(self.expr) - return c.dialect.normalize_value_by_type(expr, self.expr_type or self.expr.type) - @dataclass class ApplyFuncAndNormalizeAsString(ExprNode): expr: ExprNode apply_func: Callable = None - def compile(self, c: Compiler) -> str: - expr = self.expr - expr_type = expr.type - - if isinstance(expr_type, Native_UUID): - # Normalize first, apply template after (for uuids) - # Needed because min/max(uuid) fails in postgresql - expr = NormalizeAsString(expr, expr_type) - if self.apply_func is not None: - expr = self.apply_func(expr) # Apply template using Python's string formatting - - else: - # Apply template before normalizing (for ints) - if self.apply_func is not None: - expr = self.apply_func(expr) # Apply template using Python's string formatting - expr = NormalizeAsString(expr, expr_type) - - return c.compile(expr) - @dataclass class Checksum(ExprNode): exprs: Sequence[Expr] - - def compile(self, c: Compiler): - if len(self.exprs) > 1: - exprs = [Code(f"coalesce({c.compile(expr)}, '')") for expr in self.exprs] - # exprs = [c.compile(e) for e in exprs] - expr = Concat(exprs, "|") - else: - # No need to coalesce - safe to assume that key cannot be null - (expr,) = self.exprs - expr = c.compile(expr) - md5 = c.dialect.md5_as_int(expr) - return f"sum({md5})"