diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 60c79b10..4ae223fb 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -1,6 +1,7 @@ from typing import Sequence, Tuple, Iterator, Optional, Union from data_diff.abcs.database_types import DbTime, DbPath +from data_diff.databases import Database from data_diff.tracking import disable_tracking from data_diff.databases._connect import connect from data_diff.diff_tables import Algorithm @@ -31,10 +32,10 @@ def connect_to_table( if isinstance(key_columns, str): key_columns = (key_columns,) - db = connect(db_info, thread_count=thread_count) + db: Database = connect(db_info, thread_count=thread_count) if isinstance(table_name, str): - table_name = db.parse_table_name(table_name) + table_name = db.dialect.parse_table_name(table_name) return TableSegment(db, table_name, key_columns, **kwargs) @@ -161,7 +162,8 @@ def diff_tables( ) elif algorithm == Algorithm.JOINDIFF: if isinstance(materialize_to_table, str): - materialize_to_table = table1.database.parse_table_name(eval_name_template(materialize_to_table)) + table_name = eval_name_template(materialize_to_table) + materialize_to_table = table1.database.dialect.parse_table_name(table_name) differ = JoinDiffer( threaded=threaded, max_threadpool_size=max_threadpool_size, diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 77dc7fb6..0e5255e6 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -6,12 +6,13 @@ import json import logging from itertools import islice -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import rich from rich.logging import RichHandler import click +from data_diff import Database from data_diff.schema import create_schema from data_diff.queries.api import current_timestamp @@ -425,7 +426,7 @@ def _data_diff( logging.error(f"Error while parsing age expression: {e}") return - dbs = db1, db2 + dbs: Tuple[Database, Database] = db1, db2 if interactive: for db in dbs: @@ -444,7 
+445,7 @@ def _data_diff( materialize_all_rows=materialize_all_rows, table_write_limit=table_write_limit, materialize_to_table=materialize_to_table - and db1.parse_table_name(eval_name_template(materialize_to_table)), + and db1.dialect.parse_table_name(eval_name_template(materialize_to_table)), ) else: assert algorithm == Algorithm.HASHDIFF @@ -456,7 +457,7 @@ def _data_diff( ) table_names = table1, table2 - table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)] + table_paths = [db.dialect.parse_table_name(t) for db, t in safezip(dbs, table_names)] schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths))) schema1, schema2 = schemas = [ diff --git a/data_diff/abcs/compiler.py b/data_diff/abcs/compiler.py index 72fd7578..4a847d05 100644 --- a/data_diff/abcs/compiler.py +++ b/data_diff/abcs/compiler.py @@ -1,15 +1,9 @@ -from typing import Any, Dict -from abc import ABC, abstractmethod +from abc import ABC class AbstractCompiler(ABC): - @abstractmethod - def compile(self, elem: Any, params: Dict[str, Any] = None) -> str: - ... + pass class Compilable(ABC): - # TODO generic syntax, so we can write Compilable[T] for expressions returning a value of type T - @abstractmethod - def compile(self, c: AbstractCompiler) -> str: - ... 
+ pass diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 82ec8352..a679db67 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -1,11 +1,12 @@ import decimal from abc import ABC, abstractmethod -from typing import Sequence, Optional, Tuple, Type, Union, Dict, List +from typing import Sequence, Optional, Tuple, Union, Dict, List from datetime import datetime from runtype import dataclass from typing_extensions import Self +from data_diff.abcs.compiler import AbstractCompiler from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown @@ -176,6 +177,14 @@ class UnknownColType(ColType): class AbstractDialect(ABC): """Dialect-dependent query expressions""" + @abstractmethod + def compile(self, compiler: AbstractCompiler, elem, params=None) -> str: + raise NotImplementedError + + @abstractmethod + def parse_table_name(self, name: str) -> DbPath: + "Parse the given table name into a DbPath" + @property @abstractmethod def name(self) -> str: @@ -319,10 +328,6 @@ def _process_table_schema( """ - @abstractmethod - def parse_table_name(self, name: str) -> DbPath: - "Parse the given table name into a DbPath" - @abstractmethod def close(self): "Close connection(s) to the database instance. Querying will stop functioning." 
diff --git a/data_diff/bound_exprs.py b/data_diff/bound_exprs.py index 1742b74c..4b53846d 100644 --- a/data_diff/bound_exprs.py +++ b/data_diff/bound_exprs.py @@ -8,7 +8,6 @@ from typing_extensions import Self from data_diff.abcs.database_types import AbstractDatabase -from data_diff.abcs.compiler import AbstractCompiler from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable from data_diff.queries.api import table from data_diff.schema import create_schema @@ -37,10 +36,6 @@ def query(self, res_type=list): def type(self): return self.node.type - def compile(self, c: AbstractCompiler) -> str: - assert c.database is self.database - return self.node.compile(c) - def bind_node(node, database): return BoundNode(database, node) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index a89ab74e..dc43d8d7 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,3 +1,4 @@ +import functools from datetime import datetime import math import sys @@ -9,13 +10,25 @@ from abc import abstractmethod from uuid import UUID import decimal +import contextvars from runtype import dataclass from typing_extensions import Self -from data_diff.utils import is_uuid, safezip +from data_diff.queries.compiler import CompileError +from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString +from data_diff.utils import ArithString, is_uuid, join_iter, safezip from data_diff.queries.api import Expr, Compiler, table, Select, SKIP, Explain, Code, this -from data_diff.queries.ast_classes import Random +from data_diff.queries.ast_classes import Alias, BinOp, CaseWhen, Cast, Column, Commit, Concat, ConstantTable, Count, \ + CreateTable, Cte, \ + CurrentTimestamp, DropTable, Func, \ + GroupBy, \ + In, InsertToTable, IsDistinctFrom, \ + Join, \ + Param, \ + Random, \ + Root, TableAlias, TableOp, TablePath, TestRegex, \ + TimeTravel, TruncateTable, UnaryOp, WhenThen, _ResolveColumn from 
data_diff.abcs.database_types import ( AbstractDatabase, Array, @@ -39,16 +52,17 @@ Boolean, JSON, ) -from data_diff.abcs.mixins import Compilable +from data_diff.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel, Compilable from data_diff.abcs.mixins import ( AbstractMixin_Schema, AbstractMixin_RandomSample, AbstractMixin_NormalizeValue, AbstractMixin_OptimizerHints, ) -from data_diff.bound_exprs import bound_table +from data_diff.bound_exprs import BoundNode, bound_table logger = logging.getLogger("database") +cv_params = contextvars.ContextVar("params") def parse_table_name(t): @@ -98,7 +112,7 @@ def __init__(self, compiler: Compiler, gen: Generator): def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) while True: - sql = self.compiler.compile(q) + sql = self.compiler.database.dialect.compile(self.compiler, q) logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) try: try: @@ -156,6 +170,427 @@ class BaseDialect(AbstractDialect): PLACEHOLDER_TABLE = None # Used for Oracle + def parse_table_name(self, name: str) -> DbPath: + return parse_table_name(name) + + def compile(self, compiler: Compiler, elem, params=None) -> str: + if params: + cv_params.set(params) + + if compiler.root and isinstance(elem, Compilable) and not isinstance(elem, Root): + from data_diff.queries.ast_classes import Select + + elem = Select(columns=[elem]) + + res = self._compile(compiler, elem) + if compiler.root and compiler._subqueries: + subq = ", ".join(f"\n {k} AS ({v})" for k, v in compiler._subqueries.items()) + compiler._subqueries.clear() + return f"WITH {subq}\n{res}" + return res + + def _compile(self, compiler: Compiler, elem) -> str: + if elem is None: + return "NULL" + elif isinstance(elem, Compilable): + return self.render_compilable(compiler.replace(root=False), elem) + elif isinstance(elem, str): + return f"'{elem}'" + elif isinstance(elem, (int, float)): + return str(elem) + elif isinstance(elem, datetime): + 
return self.timestamp_value(elem) + elif isinstance(elem, bytes): + return f"b'{elem.decode()}'" + elif isinstance(elem, ArithString): + return f"'{elem}'" + assert False, elem + + def render_compilable(self, c: Compiler, elem: Compilable) -> str: + # All ifs are only for better code navigation, IDE usage detection, and type checking. + # The last catch-all would render them anyway — it is a typical "visitor" pattern. + if isinstance(elem, Column): + return self.render_column(c, elem) + elif isinstance(elem, Cte): + return self.render_cte(c, elem) + elif isinstance(elem, Commit): + return self.render_commit(c, elem) + elif isinstance(elem, Param): + return self.render_param(c, elem) + elif isinstance(elem, NormalizeAsString): + return self.render_normalizeasstring(c, elem) + elif isinstance(elem, ApplyFuncAndNormalizeAsString): + return self.render_applyfuncandnormalizeasstring(c, elem) + elif isinstance(elem, Checksum): + return self.render_checksum(c, elem) + elif isinstance(elem, Concat): + return self.render_concat(c, elem) + elif isinstance(elem, TestRegex): + return self.render_testregex(c, elem) + elif isinstance(elem, Func): + return self.render_func(c, elem) + elif isinstance(elem, WhenThen): + return self.render_whenthen(c, elem) + elif isinstance(elem, CaseWhen): + return self.render_casewhen(c, elem) + elif isinstance(elem, IsDistinctFrom): + return self.render_isdistinctfrom(c, elem) + elif isinstance(elem, UnaryOp): + return self.render_unaryop(c, elem) + elif isinstance(elem, BinOp): + return self.render_binop(c, elem) + elif isinstance(elem, TablePath): + return self.render_tablepath(c, elem) + elif isinstance(elem, TableAlias): + return self.render_tablealias(c, elem) + elif isinstance(elem, TableOp): + return self.render_tableop(c, elem) + elif isinstance(elem, Select): + return self.render_select(c, elem) + elif isinstance(elem, Join): + return self.render_join(c, elem) + elif isinstance(elem, GroupBy): + return self.render_groupby(c, elem) + 
elif isinstance(elem, Count): + return self.render_count(c, elem) + elif isinstance(elem, Alias): + return self.render_alias(c, elem) + elif isinstance(elem, In): + return self.render_in(c, elem) + elif isinstance(elem, Cast): + return self.render_cast(c, elem) + elif isinstance(elem, Random): + return self.render_random(c, elem) + elif isinstance(elem, Explain): + return self.render_explain(c, elem) + elif isinstance(elem, CurrentTimestamp): + return self.render_currenttimestamp(c, elem) + elif isinstance(elem, TimeTravel): + return self.render_timetravel(c, elem) + elif isinstance(elem, CreateTable): + return self.render_createtable(c, elem) + elif isinstance(elem, DropTable): + return self.render_droptable(c, elem) + elif isinstance(elem, TruncateTable): + return self.render_truncatetable(c, elem) + elif isinstance(elem, InsertToTable): + return self.render_inserttotable(c, elem) + elif isinstance(elem, Code): + return self.render_code(c, elem) + elif isinstance(elem, BoundNode): + return self.render_boundnode(c, elem) + elif isinstance(elem, _ResolveColumn): + return self.render__resolvecolumn(c, elem) + + method_name = f"render_{elem.__class__.__name__.lower()}" + method = getattr(self, method_name, None) + if method is not None: + return method(c, elem) + else: + raise RuntimeError(f"Cannot render AST of type {elem.__class__}") + # return elem.compile(compiler.replace(root=False)) + + def render_column(self, c: Compiler, elem: Column) -> str: + if c._table_context: + if len(c._table_context) > 1: + aliases = [ + t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is elem.source_table + ] + if not aliases: + return self.quote(elem.name) + elif len(aliases) > 1: + raise CompileError(f"Too many aliases for column {elem.name}") + (alias,) = aliases + + return f"{self.quote(alias.name)}.{self.quote(elem.name)}" + + return self.quote(elem.name) + + def render_cte(self, parent_c: Compiler, elem: Cte) -> str: + c: Compiler = 
parent_c.replace(_table_context=[], in_select=False) + compiled = self.compile(c, elem.source_table) + + name = elem.name or parent_c.new_unique_name() + name_params = f"{name}({', '.join(elem.params)})" if elem.params else name + parent_c._subqueries[name_params] = compiled + + return name + + def render_commit(self, c: Compiler, elem: Commit) -> str: + return "COMMIT" if not c.database.is_autocommit else SKIP + + def render_param(self, c: Compiler, elem: Param) -> str: + params = cv_params.get() + return self._compile(c, params[elem.name]) + + def render_normalizeasstring(self, c: Compiler, elem: NormalizeAsString) -> str: + expr = self.compile(c, elem.expr) + return self.normalize_value_by_type(expr, elem.expr_type or elem.expr.type) + + def render_applyfuncandnormalizeasstring(self, c: Compiler, elem: ApplyFuncAndNormalizeAsString) -> str: + expr = elem.expr + expr_type = expr.type + + if isinstance(expr_type, Native_UUID): + # Normalize first, apply template after (for uuids) + # Needed because min/max(uuid) fails in postgresql + expr = NormalizeAsString(expr, expr_type) + if elem.apply_func is not None: + expr = elem.apply_func(expr) # Apply template using Python's string formatting + + else: + # Apply template before normalizing (for ints) + if elem.apply_func is not None: + expr = elem.apply_func(expr) # Apply template using Python's string formatting + expr = NormalizeAsString(expr, expr_type) + + return self.compile(c, expr) + + def render_checksum(self, c: Compiler, elem: Checksum) -> str: + if len(elem.exprs) > 1: + exprs = [Code(f"coalesce({self.compile(c, expr)}, '')") for expr in elem.exprs] + # exprs = [self.compile(c, e) for e in exprs] + expr = Concat(exprs, "|") + else: + # No need to coalesce - safe to assume that key cannot be null + (expr,) = elem.exprs + expr = self.compile(c, expr) + md5 = self.md5_as_int(expr) + return f"sum({md5})" + + def render_concat(self, c: Compiler, elem: Concat) -> str: + # We coalesce because on some DBs (e.g. 
MySQL) concat('a', NULL) is NULL + items = [f"coalesce({self.compile(c, Code(self.to_string(self.compile(c, expr))))}, '')" for expr in elem.exprs] + assert items + if len(items) == 1: + return items[0] + + if elem.sep: + items = list(join_iter(f"'{elem.sep}'", items)) + return self.concat(items) + + def render_alias(self, c: Compiler, elem: Alias) -> str: + return f"{self.compile(c, elem.expr)} AS {self.quote(elem.name)}" + + def render_testregex(self, c: Compiler, elem: TestRegex) -> str: + # TODO: move this method to that mixin! raise here instead, unconditionally. + if not isinstance(self, AbstractMixin_Regex): + raise NotImplementedError(f"No regex implementation for database '{c.dialect}'") + regex = self.test_regex(elem.string, elem.pattern) + return self.compile(c, regex) + + def render_count(self, c: Compiler, elem: Count) -> str: + expr = self.compile(c, elem.expr) if elem.expr else "*" + if elem.distinct: + return f"count(distinct {expr})" + return f"count({expr})" + + def render_code(self, c: Compiler, elem: Code) -> str: + if not elem.args: + return elem.code + + args = {k: self.compile(c, v) for k, v in elem.args.items()} + return elem.code.format(**args) + + def render_func(self, c: Compiler, elem: Func) -> str: + args = ", ".join(self.compile(c, e) for e in elem.args) + return f"{elem.name}({args})" + + def render_whenthen(self, c: Compiler, elem: WhenThen) -> str: + return f"WHEN {self.compile(c, elem.when)} THEN {self.compile(c, elem.then)}" + + def render_casewhen(self, c: Compiler, elem: CaseWhen) -> str: + assert elem.cases + when_thens = " ".join(self.compile(c, case) for case in elem.cases) + else_expr = (" ELSE " + self.compile(c, elem.else_expr)) if elem.else_expr is not None else "" + return f"CASE {when_thens}{else_expr} END" + + def render_isdistinctfrom(self, c: Compiler, elem: IsDistinctFrom) -> str: + a = self.to_comparable(self.compile(c, elem.a), elem.a.type) + b = self.to_comparable(self.compile(c, elem.b), elem.b.type) + return 
self.is_distinct_from(a, b) + + def render_unaryop(self, c: Compiler, elem: UnaryOp) -> str: + return f"({elem.op}{self.compile(c, elem.expr)})" + + def render_binop(self, c: Compiler, elem: BinOp) -> str: + expr = f" {elem.op} ".join(self.compile(c, a) for a in elem.args) + return f"({expr})" + + def render_tablepath(self, c: Compiler, elem: TablePath) -> str: + path = elem.path # c.database._normalize_table_path(self.name) + return ".".join(map(self.quote, path)) + + def render_tablealias(self, c: Compiler, elem: TableAlias) -> str: + return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}" + + def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str: + c: Compiler = parent_c.replace(in_select=False) + table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}" + if parent_c.in_select: + table_expr = f"({table_expr}) {c.new_unique_name()}" + elif parent_c.in_join: + table_expr = f"({table_expr})" + return table_expr + + def render_boundnode(self, c: Compiler, elem: BoundNode) -> str: + assert self is elem.database.dialect + return self.compile(c, elem.node) + + def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str: + return self.compile(c, elem._get_resolved()) + + def render_select(self, parent_c: Compiler, elem: Select) -> str: + c: Compiler = parent_c.replace(in_select=True) # .add_table_context(self.table) + compile_fn = functools.partial(self.compile, c) + + columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*" + distinct = "DISTINCT " if elem.distinct else "" + optimizer_hints = self.optimizer_hints(elem.optimizer_hints) if elem.optimizer_hints else "" + select = f"SELECT {optimizer_hints}{distinct}{columns}" + + if elem.table: + select += " FROM " + self.compile(c, elem.table) + elif self.PLACEHOLDER_TABLE: + select += f" FROM {self.PLACEHOLDER_TABLE}" + + if elem.where_exprs: + select += " WHERE " + " AND ".join(map(compile_fn, elem.where_exprs)) + + if 
elem.group_by_exprs: + select += " GROUP BY " + ", ".join(map(compile_fn, elem.group_by_exprs)) + + if elem.having_exprs: + assert elem.group_by_exprs + select += " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) + + if elem.order_by_exprs: + select += " ORDER BY " + ", ".join(map(compile_fn, elem.order_by_exprs)) + + if elem.limit_expr is not None: + has_order_by = bool(elem.order_by_exprs) + select += " " + self.offset_limit(0, elem.limit_expr, has_order_by=has_order_by) + + if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: + select = f"({select})" + return select + + def render_join(self, parent_c: Compiler, elem: Join) -> str: + tables = [ + t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in elem.source_tables + ] + c = parent_c.add_table_context(*tables, in_join=True, in_select=False) + op = " JOIN " if elem.op is None else f" {elem.op} JOIN " + joined = op.join(self.compile(c, t) for t in tables) + + if elem.on_exprs: + on = " AND ".join(self.compile(c, e) for e in elem.on_exprs) + res = f"{joined} ON {on}" + else: + res = joined + + compile_fn = functools.partial(self.compile, c) + columns = "*" if elem.columns is None else ", ".join(map(compile_fn, elem.columns)) + select = f"SELECT {columns} FROM {res}" + + if parent_c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif parent_c.in_join: + select = f"({select})" + return select + + def render_groupby(self, c: Compiler, elem: GroupBy) -> str: + compile_fn = functools.partial(self.compile, c) + + if elem.values is None: + raise CompileError(".group_by() must be followed by a call to .agg()") + + keys = [str(i + 1) for i in range(len(elem.keys))] + columns = (elem.keys or []) + (elem.values or []) + if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None: + return self.compile( + c, + elem.table.replace( + columns=columns, + group_by_exprs=[Code(k) for k in 
keys], + having_exprs=elem.having_exprs, + ) + ) + + keys_str = ", ".join(keys) + columns_str = ", ".join(self.compile(c, x) for x in columns) + having_str = ( + " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else "" + ) + select = ( + f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" + ) + + if c.in_select: + select = f"({select}) {c.new_unique_name()}" + elif c.in_join: + select = f"({select})" + return select + + def render_in(self, c: Compiler, elem: In) -> str: + compile_fn = functools.partial(self.compile, c) + elems = ", ".join(map(compile_fn, elem.list)) + return f"({self.compile(c, elem.expr)} IN ({elems}))" + + def render_cast(self, c: Compiler, elem: Cast) -> str: + return f"cast({self.compile(c, elem.expr)} as {self.compile(c, elem.target_type)})" + + def render_random(self, c: Compiler, elem: Random) -> str: + return self.random() + + def render_explain(self, c: Compiler, elem: Explain) -> str: + return self.explain_as_text(self.compile(c, elem.select)) + + def render_currenttimestamp(self, c: Compiler, elem: CurrentTimestamp) -> str: + return self.current_timestamp() + + def render_timetravel(self, c: Compiler, elem: TimeTravel) -> str: + assert isinstance(self, AbstractMixin_TimeTravel) + return self.compile( + c, + # Time travel is implemented by a dialect mixin, so dispatch on the dialect (self), not on the compiler.
+ self.time_travel( + elem.table, before=elem.before, timestamp=elem.timestamp, offset=elem.offset, statement=elem.statement + ) + ) + + def render_createtable(self, c: Compiler, elem: CreateTable) -> str: + ne = "IF NOT EXISTS " if elem.if_not_exists else "" + if elem.source_table: + return f"CREATE TABLE {ne}{self.compile(c, elem.path)} AS {self.compile(c, elem.source_table)}" + + schema = ", ".join(f"{self.quote(k)} {self.type_repr(v)}" for k, v in elem.path.schema.items()) + pks = ( + ", PRIMARY KEY (%s)" % ", ".join(elem.primary_keys) + if elem.primary_keys and self.SUPPORTS_PRIMARY_KEY + else "" + ) + return f"CREATE TABLE {ne}{self.compile(c, elem.path)}({schema}{pks})" + + def render_droptable(self, c: Compiler, elem: DropTable) -> str: + ie = "IF EXISTS " if elem.if_exists else "" + return f"DROP TABLE {ie}{self.compile(c, elem.path)}" + + def render_truncatetable(self, c: Compiler, elem: TruncateTable) -> str: + return f"TRUNCATE TABLE {self.compile(c, elem.path)}" + + def render_inserttotable(self, c: Compiler, elem: InsertToTable) -> str: + if isinstance(elem.expr, ConstantTable): + expr = self.constant_values(elem.expr.rows) + else: + expr = self.compile(c, elem.expr) + + columns = "(%s)" % ", ".join(map(self.quote, elem.columns)) if elem.columns is not None else "" + + return f"INSERT INTO {self.compile(c, elem.path)}{columns} {expr}" + def offset_limit( self, offset: Optional[int] = None, limit: Optional[int] = None, has_order_by: Optional[bool] = None ) -> str: @@ -332,8 +767,7 @@ def name(self): return type(self).__name__ def compile(self, sql_ast): - compiler = Compiler(self) - return compiler.compile(sql_ast) + return self.dialect.compile(Compiler(self), sql_ast) def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): """Query the given SQL code/AST, and attempt to convert the result to type 'res_type' @@ -356,14 +790,14 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): else: if res_type is None: res_type = 
sql_ast.type - sql_code = compiler.compile(sql_ast) + sql_code = self.compile(sql_ast) if sql_code is SKIP: return SKIP logger.debug("Running SQL (%s): %s", self.name, sql_code) if self._interactive and isinstance(sql_ast, Select): - explained_sql = compiler.compile(Explain(sql_ast)) + explained_sql = self.compile(Explain(sql_ast)) explain = self._query(explained_sql) for row in explain: # Most returned a 1-tuple. Presto returns a string @@ -518,9 +952,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table") - def parse_table_name(self, name: str) -> DbPath: - return parse_table_name(name) - def _query_cursor(self, c, sql_code: str) -> QueryResult: assert isinstance(sql_code, str), sql_code try: diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 5925234f..feb98bde 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -212,6 +212,10 @@ def to_comparable(self, value: str, coltype: ColType) -> str: def set_timezone_to_utc(self) -> str: raise NotImplementedError() + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return tuple(i for i in path if i is not None) + class BigQuery(Database): CONNECT_URI_HELP = "bigquery:///" @@ -288,10 +292,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: [project.]schema.table" ) - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return tuple(i for i in self._normalize_table_path(path) if i is not None) - @property def is_autocommit(self) -> bool: return True diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 1b8aa33a..67d0528d 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -94,6 +94,10 @@ def _convert_db_precision_to_digits(self, p: int) -> int: def set_timezone_to_utc(self) -> str: return "SET TIME ZONE 'UTC'" + def parse_table_name(self, name: str) -> DbPath: + path = parse_table_name(name) + return tuple(i for i in path if i is not None) + class Databricks(ThreadedDatabase): dialect = Dialect() @@ -178,10 +182,6 @@ def _process_table_schema( self._refine_coltypes(path, col_dict, where) return col_dict - def parse_table_name(self, name: str) -> DbPath: - path = parse_table_name(name) - return tuple(i for i in self._normalize_table_path(path) if i is not None) - @property def is_autocommit(self) -> bool: return True diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 667786a7..c40a2b99 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -58,15 +58,15 @@ def sample(table_expr): def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str: db = c.database - c = c.replace(root=False) # we're compiling fragments, not full queries + c: Compiler = c.replace(root=False) # we're compiling fragments, not full queries if isinstance(db, BigQuery): - return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" + return f"create table {c.dialect.compile(c, path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.dialect.compile(c, expr)}" elif isinstance(db, Presto): - return f"create table {c.compile(path)} as 
{c.compile(expr)}" + return f"create table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}" elif isinstance(db, Oracle): - return f"create global temporary table {c.compile(path)} as {c.compile(expr)}" + return f"create global temporary table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}" else: - return f"create temporary table {c.compile(path)} as {c.compile(expr)}" + return f"create temporary table {c.dialect.compile(c, path)} as {c.dialect.compile(c, expr)}" def bool_to_int(x): diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 70cb355f..0013fef7 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -5,13 +5,12 @@ from runtype import dataclass from typing_extensions import Self -from data_diff.utils import join_iter, ArithString +from data_diff.utils import ArithString from data_diff.abcs.compiler import Compilable from data_diff.abcs.database_types import AbstractTable -from data_diff.abcs.mixins import AbstractMixin_Regex, AbstractMixin_TimeTravel from data_diff.schema import Schema -from data_diff.queries.compiler import Compiler, cv_params, Root, CompileError +from data_diff.queries.compiler import Compiler from data_diff.queries.base import SKIP, args_as_tuple, SqeletonError from data_diff.abcs.database_types import DbPath @@ -24,6 +23,10 @@ class QB_TypeError(QueryBuilderError): pass +class Root: + "Nodes inheriting from Root can be used as root statements in SQL (e.g. 
SELECT yes, RANDOM() no)" + + class ExprNode(Compilable): "Base class for query expression nodes" @@ -54,13 +57,6 @@ class Code(ExprNode, Root): code: str args: Dict[str, Expr] = None - def compile(self, c: Compiler) -> str: - if not self.args: - return self.code - - args = {k: c.compile(v) for k, v in self.args.items()} - return self.code.format(**args) - def _expr_type(e: Expr) -> type: if isinstance(e, ExprNode): @@ -73,9 +69,6 @@ class Alias(ExprNode): expr: Expr name: str - def compile(self, c: Compiler) -> str: - return f"{c.compile(self.expr)} AS {c.quote(self.name)}" - @property def type(self): return _expr_type(self.expr) @@ -178,32 +171,13 @@ class Concat(ExprNode): exprs: list sep: str = None - def compile(self, c: Compiler) -> str: - # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL - items = [f"coalesce({c.compile(Code(c.dialect.to_string(c.compile(expr))))}, '')" for expr in self.exprs] - assert items - if len(items) == 1: - return items[0] - - if self.sep: - items = list(join_iter(f"'{self.sep}'", items)) - return c.dialect.concat(items) - @dataclass class Count(ExprNode): expr: Expr = None distinct: bool = False - type = int - def compile(self, c: Compiler) -> str: - expr = c.compile(self.expr) if self.expr else "*" - if self.distinct: - return f"count(distinct {expr})" - - return f"count({expr})" - class LazyOps: def __add__(self, other): @@ -262,43 +236,24 @@ class TestRegex(ExprNode, LazyOps): string: Expr pattern: Expr - def compile(self, c: Compiler) -> str: - if not isinstance(c.dialect, AbstractMixin_Regex): - raise NotImplementedError(f"No regex implementation for database '{c.database}'") - regex = c.dialect.test_regex(self.string, self.pattern) - return c.compile(regex) - @dataclass(eq=False) class Func(ExprNode, LazyOps): name: str args: Sequence[Expr] - def compile(self, c: Compiler) -> str: - args = ", ".join(c.compile(e) for e in self.args) - return f"{self.name}({args})" - @dataclass class WhenThen(ExprNode): 
when: Expr then: Expr - def compile(self, c: Compiler) -> str: - return f"WHEN {c.compile(self.when)} THEN {c.compile(self.then)}" - @dataclass class CaseWhen(ExprNode): cases: Sequence[WhenThen] else_expr: Expr = None - def compile(self, c: Compiler) -> str: - assert self.cases - when_thens = " ".join(c.compile(case) for case in self.cases) - else_expr = (" ELSE " + c.compile(self.else_expr)) if self.else_expr is not None else "" - return f"CASE {when_thens}{else_expr} END" - @property def type(self): then_types = {_expr_type(case.then) for case in self.cases} @@ -353,21 +308,12 @@ class IsDistinctFrom(ExprNode, LazyOps): b: Expr type = bool - def compile(self, c: Compiler) -> str: - a = c.dialect.to_comparable(c.compile(self.a), self.a.type) - b = c.dialect.to_comparable(c.compile(self.b), self.b.type) - return c.dialect.is_distinct_from(a, b) - @dataclass(eq=False, order=False) class BinOp(ExprNode, LazyOps): op: str args: Sequence[Expr] - def compile(self, c: Compiler) -> str: - expr = f" {self.op} ".join(c.compile(a) for a in self.args) - return f"({expr})" - @property def type(self): types = {_expr_type(i) for i in self.args} @@ -382,9 +328,6 @@ class UnaryOp(ExprNode, LazyOps): op: str expr: Expr - def compile(self, c: Compiler) -> str: - return f"({self.op}{c.compile(self.expr)})" - class BinBoolOp(BinOp): type = bool @@ -401,22 +344,6 @@ def type(self): raise QueryBuilderError(f"Schema required for table {self.source_table}") return self.source_table.schema[self.name] - def compile(self, c: Compiler) -> str: - if c._table_context: - if len(c._table_context) > 1: - aliases = [ - t for t in c._table_context if isinstance(t, TableAlias) and t.source_table is self.source_table - ] - if not aliases: - return c.quote(self.name) - elif len(aliases) > 1: - raise CompileError(f"Too many aliases for column {self.name}") - (alias,) = aliases - - return f"{c.quote(alias.name)}.{c.quote(self.name)}" - - return c.quote(self.name) - @dataclass class TablePath(ExprNode, 
ITable): @@ -427,10 +354,6 @@ class TablePath(ExprNode, ITable): def source_table(self) -> Self: return self - def compile(self, c: Compiler) -> str: - path = self.path # c.database._normalize_table_path(self.name) - return ".".join(map(c.quote, path)) - # Statement shorthands def create(self, source_table: ITable = None, *, if_not_exists: bool = False, primary_keys: List[str] = None): """Returns a query expression to create a new table. @@ -514,9 +437,6 @@ class TableAlias(ExprNode, ITable): source_table: ITable name: str - def compile(self, c: Compiler) -> str: - return f"{c.compile(self.source_table)} {c.quote(self.name)}" - @dataclass class Join(ExprNode, ITable, Root): @@ -564,29 +484,6 @@ def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: # TODO Ensure exprs <= self.columns ? return self.replace(columns=exprs) - def compile(self, parent_c: Compiler) -> str: - tables = [ - t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables - ] - c = parent_c.add_table_context(*tables, in_join=True, in_select=False) - op = " JOIN " if self.op is None else f" {self.op} JOIN " - joined = op.join(c.compile(t) for t in tables) - - if self.on_exprs: - on = " AND ".join(c.compile(e) for e in self.on_exprs) - res = f"{joined} ON {on}" - else: - res = joined - - columns = "*" if self.columns is None else ", ".join(map(c.compile, self.columns)) - select = f"SELECT {columns} FROM {res}" - - if parent_c.in_select: - select = f"({select}) {c.new_unique_name()}" - elif parent_c.in_join: - select = f"({select})" - return select - @dataclass class GroupBy(ExprNode, ITable, Root): @@ -619,36 +516,6 @@ def agg(self, *exprs) -> Self: resolve_names(self.table, exprs) return self.replace(values=(self.values or []) + exprs) - def compile(self, c: Compiler) -> str: - if self.values is None: - raise CompileError(".group_by() must be followed by a call to .agg()") - - keys = [str(i + 1) for i in range(len(self.keys))] - columns = 
(self.keys or []) + (self.values or []) - if isinstance(self.table, Select) and self.table.columns is None and self.table.group_by_exprs is None: - return c.compile( - self.table.replace( - columns=columns, - group_by_exprs=[Code(k) for k in keys], - having_exprs=self.having_exprs, - ) - ) - - keys_str = ", ".join(keys) - columns_str = ", ".join(c.compile(x) for x in columns) - having_str = ( - " HAVING " + " AND ".join(map(c.compile, self.having_exprs)) if self.having_exprs is not None else "" - ) - select = ( - f"SELECT {columns_str} FROM {c.replace(in_select=True).compile(self.table)} GROUP BY {keys_str}{having_str}" - ) - - if c.in_select: - select = f"({select}) {c.new_unique_name()}" - elif c.in_join: - select = f"({select})" - return select - @dataclass class TableOp(ExprNode, ITable, Root): @@ -672,15 +539,6 @@ def schema(self): assert len(s1) == len(s2) return s1 - def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(in_select=False) - table_expr = f"{c.compile(self.table1)} {self.op} {c.compile(self.table2)}" - if parent_c.in_select: - table_expr = f"({table_expr}) {c.new_unique_name()}" - elif parent_c.in_join: - table_expr = f"({table_expr})" - return table_expr - @dataclass class Select(ExprNode, ITable, Root): @@ -705,42 +563,6 @@ def schema(self): def source_table(self): return self - def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(in_select=True) # .add_table_context(self.table) - - columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*" - distinct = "DISTINCT " if self.distinct else "" - optimizer_hints = c.dialect.optimizer_hints(self.optimizer_hints) if self.optimizer_hints else "" - select = f"SELECT {optimizer_hints}{distinct}{columns}" - - if self.table: - select += " FROM " + c.compile(self.table) - elif c.dialect.PLACEHOLDER_TABLE: - select += f" FROM {c.dialect.PLACEHOLDER_TABLE}" - - if self.where_exprs: - select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs)) - - if 
self.group_by_exprs: - select += " GROUP BY " + ", ".join(map(c.compile, self.group_by_exprs)) - - if self.having_exprs: - assert self.group_by_exprs - select += " HAVING " + " AND ".join(map(c.compile, self.having_exprs)) - - if self.order_by_exprs: - select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs)) - - if self.limit_expr is not None: - has_order_by = bool(self.order_by_exprs) - select += " " + c.dialect.offset_limit(0, self.limit_expr, has_order_by=has_order_by) - - if parent_c.in_select: - select = f"({select}) {c.new_unique_name()}" - elif parent_c.in_join: - select = f"({select})" - return select - @classmethod def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs): assert "table" not in kwargs @@ -783,16 +605,6 @@ class Cte(ExprNode, ITable): name: str = None params: Sequence[str] = None - def compile(self, parent_c: Compiler) -> str: - c = parent_c.replace(_table_context=[], in_select=False) - compiled = c.compile(self.source_table) - - name = self.name or parent_c.new_unique_name() - name_params = f"{name}({', '.join(self.params)})" if self.params else name - parent_c._subqueries[name_params] = compiled - - return name - @property def schema(self): # TODO add cte to schema @@ -829,9 +641,6 @@ def _get_resolved(self) -> Expr: raise QueryBuilderError(f"Column not resolved: {self.resolve_name}") return self.resolved - def compile(self, c: Compiler) -> str: - return self._get_resolved().compile(c) - @property def type(self): return self._get_resolved().type @@ -860,58 +669,34 @@ def __getitem__(self, name): class In(ExprNode): expr: Expr list: Sequence[Expr] - type = bool - def compile(self, c: Compiler): - elems = ", ".join(map(c.compile, self.list)) - return f"({c.compile(self.expr)} IN ({elems}))" - @dataclass class Cast(ExprNode): expr: Expr target_type: Expr - def compile(self, c: Compiler) -> str: - return f"cast({c.compile(self.expr)} as {c.compile(self.target_type)})" - @dataclass class 
Random(ExprNode, LazyOps): type = float - def compile(self, c: Compiler) -> str: - return c.dialect.random() - @dataclass class ConstantTable(ExprNode): rows: Sequence[Sequence] - def compile(self, c: Compiler) -> str: - raise NotImplementedError() - - def compile_for_insert(self, c: Compiler): - return c.dialect.constant_values(self.rows) - @dataclass class Explain(ExprNode, Root): select: Select - type = str - def compile(self, c: Compiler) -> str: - return c.dialect.explain_as_text(c.compile(self.select)) - class CurrentTimestamp(ExprNode): type = datetime - def compile(self, c: Compiler) -> str: - return c.dialect.current_timestamp() - @dataclass class TimeTravel(ITable): @@ -921,14 +706,6 @@ class TimeTravel(ITable): offset: int = None statement: str = None - def compile(self, c: Compiler) -> str: - assert isinstance(c, AbstractMixin_TimeTravel) - return c.compile( - c.time_travel( - self.table, before=self.before, timestamp=self.timestamp, offset=self.offset, statement=self.statement - ) - ) - # DDL @@ -944,37 +721,17 @@ class CreateTable(Statement): if_not_exists: bool = False primary_keys: List[str] = None - def compile(self, c: Compiler) -> str: - ne = "IF NOT EXISTS " if self.if_not_exists else "" - if self.source_table: - return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}" - - schema = ", ".join(f"{c.dialect.quote(k)} {c.dialect.type_repr(v)}" for k, v in self.path.schema.items()) - pks = ( - ", PRIMARY KEY (%s)" % ", ".join(self.primary_keys) - if self.primary_keys and c.dialect.SUPPORTS_PRIMARY_KEY - else "" - ) - return f"CREATE TABLE {ne}{c.compile(self.path)}({schema}{pks})" - @dataclass class DropTable(Statement): path: TablePath if_exists: bool = False - def compile(self, c: Compiler) -> str: - ie = "IF EXISTS " if self.if_exists else "" - return f"DROP TABLE {ie}{c.compile(self.path)}" - @dataclass class TruncateTable(Statement): path: TablePath - def compile(self, c: Compiler) -> str: - return f"TRUNCATE TABLE 
{c.compile(self.path)}" - @dataclass class InsertToTable(Statement): @@ -983,16 +740,6 @@ class InsertToTable(Statement): columns: List[str] = None returning_exprs: List[str] = None - def compile(self, c: Compiler) -> str: - if isinstance(self.expr, ConstantTable): - expr = self.expr.compile_for_insert(c) - else: - expr = c.compile(self.expr) - - columns = "(%s)" % ", ".join(map(c.quote, self.columns)) if self.columns is not None else "" - - return f"INSERT INTO {c.compile(self.path)}{columns} {expr}" - def returning(self, *exprs) -> Self: """Add a 'RETURNING' clause to the insert expression. @@ -1014,20 +761,12 @@ def returning(self, *exprs) -> Self: class Commit(Statement): """Generate a COMMIT statement, if we're in the middle of a transaction, or in auto-commit. Otherwise SKIP.""" - def compile(self, c: Compiler) -> str: - return "COMMIT" if not c.database.is_autocommit else SKIP - @dataclass class Param(ExprNode, ITable): """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" - name: str @property def source_table(self): return self - - def compile(self, c: Compiler) -> str: - params = cv_params.get() - return c._compile(params[self.name]) diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index e6246236..224ad636 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -1,32 +1,32 @@ import random from dataclasses import field -from datetime import datetime from typing import Any, Dict, Sequence, List from runtype import dataclass from typing_extensions import Self -from data_diff.utils import ArithString from data_diff.abcs.database_types import AbstractDatabase, AbstractDialect, DbPath -from data_diff.abcs.compiler import AbstractCompiler, Compilable - -import contextvars - -cv_params = contextvars.ContextVar("params") +from data_diff.abcs.compiler import AbstractCompiler class CompileError(Exception): pass -class Root: - "Nodes inheriting from Root can be used 
as root statements in SQL (e.g. SELECT yes, RANDOM() no)" - - @dataclass class Compiler(AbstractCompiler): + """ + Compiler bears the context for a single compilation. + + There can be multiple compilations per app run. + There can be multiple compilers in one compilation (with varying contexts). + """ + + # Database is needed to normalize tables. Dialect is needed for recursive compilations. + # In theory, it is a many-to-many relation: e.g. a generic ODBC driver with multiple dialects. + # In practice, we currently bind the dialects to the specific database classes. database: AbstractDatabase - params: dict = field(default_factory=dict) + in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag @@ -40,38 +40,9 @@ class Compiler(AbstractCompiler): def dialect(self) -> AbstractDialect: return self.database.dialect + # TODO: DEPRECATED: Remove once the dialect is used directly in all places. def compile(self, elem, params=None) -> str: - if params: - cv_params.set(params) - - if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root): - from data_diff.queries.ast_classes import Select - - elem = Select(columns=[elem]) - - res = self._compile(elem) - if self.root and self._subqueries: - subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items()) - self._subqueries.clear() - return f"WITH {subq}\n{res}" - return res - - def _compile(self, elem) -> str: - if elem is None: - return "NULL" - elif isinstance(elem, Compilable): - return elem.compile(self.replace(root=False)) - elif isinstance(elem, str): - return f"'{elem}'" - elif isinstance(elem, (int, float)): - return str(elem) - elif isinstance(elem, datetime): - return self.dialect.timestamp_value(elem) - elif isinstance(elem, bytes): - return f"b'{elem.decode()}'" - elif isinstance(elem, ArithString): - return f"'{elem}'" - assert False, elem + return self.dialect.compile(self, elem, params) def new_unique_name(self, prefix="tmp"): 
self._counter[0] += 1 @@ -79,10 +50,8 @@ def new_unique_name(self, prefix="tmp"): def new_unique_table_name(self, prefix="tmp") -> DbPath: self._counter[0] += 1 - return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") + table_name = f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}" + return self.database.dialect.parse_table_name(table_name) def add_table_context(self, *tables: Sequence, **kw) -> Self: return self.replace(_table_context=self._table_context + list(tables), **kw) - - def quote(self, s: str): - return self.dialect.quote(s) diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index 8e916601..bb0c8299 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -1,12 +1,10 @@ "Useful AST classes that don't quite fall within the scope of regular SQL" - from typing import Callable, Sequence from runtype import dataclass -from data_diff.abcs.database_types import ColType, Native_UUID +from data_diff.abcs.database_types import ColType -from data_diff.queries.compiler import Compiler -from data_diff.queries.ast_classes import Expr, ExprNode, Concat, Code +from data_diff.queries.ast_classes import Expr, ExprNode @dataclass @@ -15,48 +13,13 @@ class NormalizeAsString(ExprNode): expr_type: ColType = None type = str - def compile(self, c: Compiler) -> str: - expr = c.compile(self.expr) - return c.dialect.normalize_value_by_type(expr, self.expr_type or self.expr.type) - @dataclass class ApplyFuncAndNormalizeAsString(ExprNode): expr: ExprNode apply_func: Callable = None - def compile(self, c: Compiler) -> str: - expr = self.expr - expr_type = expr.type - - if isinstance(expr_type, Native_UUID): - # Normalize first, apply template after (for uuids) - # Needed because min/max(uuid) fails in postgresql - expr = NormalizeAsString(expr, expr_type) - if self.apply_func is not None: - expr = self.apply_func(expr) # Apply template using Python's string formatting - - else: - # Apply 
template before normalizing (for ints) - if self.apply_func is not None: - expr = self.apply_func(expr) # Apply template using Python's string formatting - expr = NormalizeAsString(expr, expr_type) - - return c.compile(expr) - @dataclass class Checksum(ExprNode): exprs: Sequence[Expr] - - def compile(self, c: Compiler): - if len(self.exprs) > 1: - exprs = [Code(f"coalesce({c.compile(expr)}, '')") for expr in self.exprs] - # exprs = [c.compile(e) for e in exprs] - expr = Concat(exprs, "|") - else: - # No need to coalesce - safe to assume that key cannot be null - (expr,) = self.exprs - expr = c.compile(expr) - md5 = c.dialect.md5_as_int(expr) - return f"sum({md5})"