@@ -1571,8 +1571,7 @@ def table_diff(
1571
1571
target : str ,
1572
1572
on : t .List [str ] | exp .Condition | None = None ,
1573
1573
skip_columns : t .List [str ] | None = None ,
1574
- model_or_snapshot : t .Optional [ModelOrSnapshot ] = None ,
1575
- select_model : t .Optional [t .Collection [str ]] = None ,
1574
+ select_models : t .Optional [t .Collection [str ]] = None ,
1576
1575
where : t .Optional [str | exp .Condition ] = None ,
1577
1576
limit : int = 20 ,
1578
1577
show : bool = True ,
@@ -1589,7 +1588,7 @@ def table_diff(
1589
1588
on: The join condition, table aliases must be "s" and "t" for source and target.
1590
1589
If omitted, the table's grain will be used.
1591
1590
skip_columns: The columns to skip when computing the table diff.
1592
- model_or_snapshot : The model or snapshot to use when environments are passed in.
1591
+ select_models : The models or snapshots to use when environments are passed in.
1593
1592
where: An optional where statement to filter results.
1594
1593
limit: The limit of the sample dataframe.
1595
1594
show: Show the table diff output in the console.
@@ -1605,51 +1604,81 @@ def table_diff(
1605
1604
table_diffs : t .List [TableDiff ] = []
1606
1605
1607
1606
# Diffs multiple or a single model across two environments
1608
- if model_or_snapshot or select_model :
1607
+ if select_models :
1609
1608
source_env = self .state_reader .get_environment (source )
1610
1609
target_env = self .state_reader .get_environment (target )
1611
1610
if not source_env :
1612
1611
raise SQLMeshError (f"Could not find environment '{ source } '" )
1613
1612
if not target_env :
1614
1613
raise SQLMeshError (f"Could not find environment '{ target } '" )
1615
1614
1616
- modified_snapshots : t .Set [ModelOrSnapshot ] = (
1617
- {model_or_snapshot } if model_or_snapshot else set ()
1618
- )
1619
- if select_model :
1620
- models_to_diff = self ._new_selector ().expand_model_selections (select_model )
1621
- target_snapshots = {
1622
- s .name : s
1623
- for s in self .state_reader .get_snapshots (target_env .snapshots ).values ()
1624
- if s .name in models_to_diff
1625
- }
1626
- context_diff = self ._context_diff (
1627
- source ,
1628
- snapshots = target_snapshots ,
1629
- ensure_finalized_snapshots = self .config .plan .use_finalized_state ,
1615
+ selected_models = self ._new_selector ().expand_model_selections (select_models )
1616
+ models_to_diff : t .List [t .Tuple [Model , EngineAdapter , str , str ]] = []
1617
+ models_in_source : t .List [str ] = []
1618
+ models_in_target : t .List [str ] = []
1619
+ models_no_diff : t .List [str ] = []
1620
+
1621
+ for model_or_snapshot in selected_models :
1622
+ model = self .get_model (model_or_snapshot , raise_if_missing = True )
1623
+ adapter = self ._get_engine_adapter (model .gateway )
1624
+ source_snapshot = next (
1625
+ (snapshot for snapshot in source_env .snapshots if snapshot .name == model .fqn ),
1626
+ None ,
1630
1627
)
1631
- modified_snapshots = {
1632
- current_snapshot .snapshot_id .name
1633
- for _ , (current_snapshot , _ ) in context_diff .modified_snapshots .items ()
1634
- }
1635
- tasks_num = min (len (modified_snapshots ), self .concurrent_tasks )
1636
- table_diffs = concurrent_apply_to_values (
1637
- list (modified_snapshots ),
1638
- lambda s : self ._model_diff (
1639
- source_env = source_env ,
1640
- target_env = target_env ,
1641
- model_or_snapshot = s ,
1642
- limit = limit ,
1643
- decimals = decimals ,
1644
- on = on ,
1645
- skip_columns = skip_columns ,
1646
- where = where ,
1647
- show = show ,
1648
- temp_schema = temp_schema ,
1649
- skip_grain_check = skip_grain_check ,
1650
- ),
1651
- tasks_num = tasks_num ,
1628
+ target_snapshot = next (
1629
+ (snapshot for snapshot in target_env .snapshots if snapshot .name == model .fqn ),
1630
+ None ,
1631
+ )
1632
+ if source_snapshot is None and target_snapshot :
1633
+ models_in_source .append (model_or_snapshot )
1634
+ elif target_snapshot is None and source_snapshot :
1635
+ models_in_target .append (model_or_snapshot )
1636
+ elif target_snapshot and source_snapshot :
1637
+ if source_snapshot .fingerprint != target_snapshot .fingerprint :
1638
+ # Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
1639
+ # to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
1640
+ source = source_snapshot .qualified_view_name .for_environment (
1641
+ source_env .naming_info , adapter .dialect
1642
+ )
1643
+ target = target_snapshot .qualified_view_name .for_environment (
1644
+ target_env .naming_info , adapter .dialect
1645
+ )
1646
+
1647
+ models_to_diff .append ((model , adapter , source , target ))
1648
+ else :
1649
+ models_no_diff .append (model_or_snapshot )
1650
+
1651
+ self .console .show_table_diff_details (
1652
+ models_in_source ,
1653
+ models_in_target ,
1654
+ models_no_diff ,
1655
+ [model [0 ].name for model in models_to_diff ],
1652
1656
)
1657
+
1658
+ if models_to_diff :
1659
+ self .console .start_table_diff_progress (len (models_to_diff ))
1660
+ tasks_num = min (len (models_to_diff ), self .concurrent_tasks )
1661
+ table_diffs = concurrent_apply_to_values (
1662
+ list (models_to_diff ),
1663
+ lambda model_info : self ._model_diff (
1664
+ model = model_info [0 ],
1665
+ adapter = model_info [1 ],
1666
+ source = model_info [2 ],
1667
+ target = model_info [3 ],
1668
+ source_alias = source_env .name ,
1669
+ target_alias = target_env .name ,
1670
+ limit = limit ,
1671
+ decimals = decimals ,
1672
+ on = on ,
1673
+ skip_columns = skip_columns ,
1674
+ where = where ,
1675
+ show = show ,
1676
+ temp_schema = temp_schema ,
1677
+ skip_grain_check = skip_grain_check ,
1678
+ ),
1679
+ tasks_num = tasks_num ,
1680
+ )
1681
+ self .console .stop_table_diff_progress ()
1653
1682
else :
1654
1683
table_diffs = [
1655
1684
self ._table_diff (
@@ -1673,9 +1702,12 @@ def table_diff(
1673
1702
1674
1703
def _model_diff (
1675
1704
self ,
1676
- source_env : Environment ,
1677
- target_env : Environment ,
1678
- model_or_snapshot : ModelOrSnapshot ,
1705
+ model : Model ,
1706
+ adapter : EngineAdapter ,
1707
+ source : str ,
1708
+ target : str ,
1709
+ source_alias : str ,
1710
+ target_alias : str ,
1679
1711
limit : int ,
1680
1712
decimals : int ,
1681
1713
on : t .Optional [t .List [str ] | exp .Condition ] = None ,
@@ -1685,22 +1717,6 @@ def _model_diff(
1685
1717
temp_schema : t .Optional [str ] = None ,
1686
1718
skip_grain_check : bool = False ,
1687
1719
) -> TableDiff :
1688
- model = self .get_model (model_or_snapshot , raise_if_missing = True )
1689
- adapter = self ._get_engine_adapter (model .gateway )
1690
-
1691
- # Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
1692
- # to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
1693
- source = next (
1694
- snapshot for snapshot in source_env .snapshots if snapshot .name == model .fqn
1695
- ).qualified_view_name .for_environment (source_env .naming_info , adapter .dialect )
1696
-
1697
- target = next (
1698
- snapshot for snapshot in target_env .snapshots if snapshot .name == model .fqn
1699
- ).qualified_view_name .for_environment (target_env .naming_info , adapter .dialect )
1700
-
1701
- source_alias = source_env .name
1702
- target_alias = target_env .name
1703
-
1704
1720
if not on :
1705
1721
on = []
1706
1722
for expr in [ref .expression for ref in model .all_references if ref .unique ]:
@@ -1710,6 +1726,8 @@ def _model_diff(
1710
1726
# Handle a single Column or Paren expression
1711
1727
on .append (expr .this .sql (dialect = adapter .dialect ))
1712
1728
1729
+ self .console .start_table_diff_model_progress (model .name )
1730
+
1713
1731
table_diff = self ._table_diff (
1714
1732
on = on ,
1715
1733
skip_columns = skip_columns ,
@@ -1723,10 +1741,13 @@ def _model_diff(
1723
1741
source_alias = source_alias ,
1724
1742
target_alias = target_alias ,
1725
1743
)
1726
- # Trigger row_diff in parallel execution so it's available for ordered display later
1744
+
1727
1745
if show :
1746
+ # Trigger row_diff in parallel execution so it's available for ordered display later
1728
1747
table_diff .row_diff (temp_schema = temp_schema , skip_grain_check = skip_grain_check )
1729
1748
1749
+ self .console .update_table_diff_progress (model .name )
1750
+
1730
1751
return table_diff
1731
1752
1732
1753
def _table_diff (
0 commit comments