From 530a1f6f05d911357f7cfced1231e8c686706eb8 Mon Sep 17 00:00:00 2001
From: Sergey Vasilyev
Date: Wed, 3 Jan 2024 20:57:39 +0100
Subject: [PATCH] Group rows by all columns of composite PKs

---
 data_diff/hashdiff_tables.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py
index 6a7fe4a3..29508965 100644
--- a/data_diff/hashdiff_tables.py
+++ b/data_diff/hashdiff_tables.py
@@ -23,7 +23,7 @@
 
 # Just for local readability: TODO: later switch to real type declarations of these.
 _Op = Literal["+", "-"]
-_PK = Any
+_PK = Sequence[Any]
 _Row = Tuple[Any]
 
 
@@ -34,6 +34,8 @@ def diff_sets(
     json_cols: dict = None,
     columns1: Sequence[str],
     columns2: Sequence[str],
+    key_columns1: Sequence[str],
+    key_columns2: Sequence[str],
     ignored_columns1: Collection[str],
     ignored_columns2: Collection[str],
 ) -> Iterator:
@@ -41,17 +43,18 @@ def diff_sets(
     sa: Set[_Row] = {tuple(val for col, val in safezip(columns1, row) if col not in ignored_columns1) for row in a}
     sb: Set[_Row] = {tuple(val for col, val in safezip(columns2, row) if col not in ignored_columns2) for row in b}
 
-    # The first item is always the key (see TableDiffer.relevant_columns)
-    # TODO update when we add compound keys to hashdiff
+    # The first items are always the PK (see TableSegment.relevant_columns)
     diffs_by_pks: Dict[_PK, List[Tuple[_Op, _Row]]] = defaultdict(list)
     for row in a:
+        pk: _PK = tuple(val for col, val in zip(key_columns1, row))
         cutrow: _Row = tuple(val for col, val in zip(columns1, row) if col not in ignored_columns1)
         if cutrow not in sb:
-            diffs_by_pks[row[0]].append(("-", row))
+            diffs_by_pks[pk].append(("-", row))
     for row in b:
+        pk: _PK = tuple(val for col, val in zip(key_columns2, row))
         cutrow: _Row = tuple(val for col, val in zip(columns2, row) if col not in ignored_columns2)
         if cutrow not in sa:
-            diffs_by_pks[row[0]].append(("+", row))
+            diffs_by_pks[pk].append(("+", row))
 
     warned_diff_cols = set()
     for diffs in (diffs_by_pks[pk] for pk in sorted(diffs_by_pks)):
@@ -232,6 +235,8 @@ def _bisect_and_diff_segments(
                     json_cols=json_cols,
                     columns1=table1.relevant_columns,
                     columns2=table2.relevant_columns,
+                    key_columns1=table1.key_columns,
+                    key_columns2=table2.key_columns,
                     ignored_columns1=self.ignored_columns1,
                     ignored_columns2=self.ignored_columns2,
                 )