Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit acd3b08

Browse files
author
Sergey Vasilyev
committed
Compare ARRAY & STRUCT types in BigQuery (simplistically)
1 parent cfd941f commit acd3b08

File tree

5 files changed

+94
-5
lines changed

5 files changed

+94
-5
lines changed

data_diff/sqeleton/abcs/database_types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
DbTime = datetime
1414

1515

16+
@dataclass
1617
class ColType:
1718
supported = True
1819

@@ -143,6 +144,18 @@ def __post_init__(self):
143144
assert self.precision == 0
144145

145146

147+
@dataclass
148+
class Array(ColType):
149+
item_type: ColType
150+
151+
152+
@dataclass
153+
class Struct(ColType):
154+
# TODO: Later, add the exact definition of the struct so that we could hash & compare individual fields.
155+
# Meanwhile, we rely on JSON comparison, so the internal structure does not matter.
156+
pass
157+
158+
146159
@dataclass
147160
class UnknownColType(ColType):
148161
text: str
@@ -221,6 +234,10 @@ def parse_type(
221234
) -> ColType:
222235
"Parse type info as returned by the database"
223236

237+
@abstractmethod
238+
def to_comparable(self, value: str, coltype: ColType) -> str:
239+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
240+
224241

225242
from typing import TypeVar, Generic
226243

data_diff/sqeleton/abcs/mixins.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID
2+
from .database_types import Array, Struct, TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID
33
from .compiler import Compilable
44

55

@@ -8,6 +8,11 @@ class AbstractMixin(ABC):
88

99

1010
class AbstractMixin_NormalizeValue(AbstractMixin):
11+
12+
@abstractmethod
13+
def to_comparable(self, value: str, coltype: ColType) -> str:
14+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
15+
1116
@abstractmethod
1217
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
1318
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
@@ -43,6 +48,14 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
4348
"""Creates an SQL expression, that converts 'value' to either '0' or '1'."""
4449
return self.to_string(value)
4550

51+
def normalize_array(self, value: str, _coltype: Array) -> str:
52+
"""Creates an SQL expression, that serialized an array into a JSON string."""
53+
return self.to_string(value)
54+
55+
def normalize_struct(self, value: str, _coltype: Struct) -> str:
56+
"""Creates an SQL expression, that serialized a struct into a JSON string."""
57+
return self.to_string(value)
58+
4659
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
4760
"""Creates an SQL expression, that strips uuids of artifacts like whitespace."""
4861
if isinstance(coltype, String_UUID):
@@ -73,6 +86,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
7386
return self.normalize_uuid(value, coltype)
7487
elif isinstance(coltype, Boolean):
7588
return self.normalize_boolean(value, coltype)
89+
elif isinstance(coltype, Array):
90+
return self.normalize_array(value, coltype)
91+
elif isinstance(coltype, Struct):
92+
return self.normalize_struct(value, coltype)
7693
return self.to_string(value)
7794

7895

data_diff/sqeleton/databases/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ def concat(self, items: List[str]) -> str:
164164
joined_exprs = ", ".join(items)
165165
return f"concat({joined_exprs})"
166166

167+
def to_comparable(self, value: str, coltype: ColType) -> str:
168+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
169+
return value
170+
167171
def is_distinct_from(self, a: str, b: str) -> str:
168172
return f"{a} is distinct from {b}"
169173

@@ -228,7 +232,7 @@ def parse_type(
228232
""" """
229233

230234
cls = self._parse_type_repr(type_repr)
231-
if not cls:
235+
if cls is None:
232236
return UnknownColType(type_repr)
233237

234238
if issubclass(cls, TemporalType):

data_diff/sqeleton/databases/bigquery.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import List, Union
1+
import re
2+
from typing import Any, List, Union
23
from ..abcs.database_types import (
3-
Timestamp,
4+
Array,
5+
ColType, Struct, Timestamp,
46
Datetime,
57
Integer,
68
Decimal,
@@ -10,6 +12,7 @@
1012
FractionalType,
1113
TemporalType,
1214
Boolean,
15+
UnknownColType,
1316
)
1417
from ..abcs.mixins import (
1518
AbstractMixin_MD5,
@@ -36,6 +39,7 @@ def md5_as_int(self, s: str) -> str:
3639

3740

3841
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
42+
3943
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
4044
if coltype.rounds:
4145
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
@@ -57,6 +61,20 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
5761
def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
5862
return self.to_string(f"cast({value} as int)")
5963

64+
def normalize_array(self, value: str, _coltype: Array) -> str:
65+
# BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.:
66+
# Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
67+
# So we do the best effort and compare it as strings, hoping that the JSON forms
68+
# match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
69+
return f"to_json_string({value})"
70+
71+
def normalize_struct(self, value: str, _coltype: Struct) -> str:
72+
# BigQuery is unable to compare arrays & structs with ==/!=/distinct from, e.g.:
73+
# Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
74+
# So we do the best effort and compare it as strings, hoping that the JSON forms
75+
# match on both sides: i.e. have properly ordered keys, same spacing, same quotes, etc.
76+
return f"to_json_string({value})"
77+
6078

6179
class Mixin_Schema(AbstractMixin_Schema):
6280
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
@@ -117,6 +135,8 @@ class Dialect(BaseDialect, Mixin_Schema):
117135
# Boolean
118136
"BOOL": Boolean,
119137
}
138+
TYPE_ARRAY_RE = re.compile(r'ARRAY<(.+)>')
139+
TYPE_STRUCT_RE = re.compile(r'STRUCT<(.+)>')
120140
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}
121141

122142
def random(self) -> str:
@@ -134,6 +154,35 @@ def type_repr(self, t) -> str:
134154
except KeyError:
135155
return super().type_repr(t)
136156

157+
def parse_type(
158+
self,
159+
table_path: DbPath,
160+
col_name: str,
161+
type_repr: str,
162+
*args: Any, # pass-through args
163+
**kwargs: Any, # pass-through args
164+
) -> ColType:
165+
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
166+
if isinstance(col_type, UnknownColType):
167+
168+
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
169+
if m:
170+
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
171+
col_type = Array(item_type=item_type)
172+
173+
m = self.TYPE_STRUCT_RE.fullmatch(type_repr)
174+
if m:
175+
col_type = Struct()
176+
177+
return col_type
178+
179+
def to_comparable(self, value: str, coltype: ColType) -> str:
180+
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``."""
181+
if isinstance(coltype, (Array, Struct)):
182+
return self.normalize_value_by_type(value, coltype)
183+
else:
184+
return super().to_comparable(value, coltype)
185+
137186
def set_timezone_to_utc(self) -> str:
138187
raise NotImplementedError()
139188

data_diff/sqeleton/queries/ast_classes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ class IsDistinctFrom(ExprNode, LazyOps):
352352
type = bool
353353

354354
def compile(self, c: Compiler) -> str:
355-
return c.dialect.is_distinct_from(c.compile(self.a), c.compile(self.b))
355+
a = c.dialect.to_comparable(c.compile(self.a), self.a.type)
356+
b = c.dialect.to_comparable(c.compile(self.b), self.b.type)
357+
return c.dialect.is_distinct_from(a, b)
356358

357359

358360
@dataclass(eq=False, order=False)

0 commit comments

Comments
 (0)