Skip to content

Commit 0270149

Browse files
refactors; add progress bar; add info
1 parent febb8fb commit 0270149

File tree

8 files changed

+237
-89
lines changed

8 files changed

+237
-89
lines changed

sqlmesh/cli/main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -905,10 +905,11 @@ def table_diff(
905905
) -> None:
906906
"""Show the diff between two tables or a selection of models when they are specified."""
907907
source, target = source_to_target.split(":")
908+
select_models = {model} if model else kwargs.pop("select_model", None)
908909
obj.table_diff(
909910
source=source,
910911
target=target,
911-
model_or_snapshot=model,
912+
select_models=select_models,
912913
**kwargs,
913914
)
914915

sqlmesh/core/console.py

+133
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,32 @@ def show_table_diff(
227227
) -> None:
228228
"""Display the table diff between two or multiple tables."""
229229

230+
@abc.abstractmethod
231+
def update_table_diff_progress(self, model: str) -> None:
232+
"""Update table diff progress bar"""
233+
234+
@abc.abstractmethod
235+
def start_table_diff_progress(self, models_to_diff: int) -> None:
236+
"""Start table diff progress bar"""
237+
238+
@abc.abstractmethod
239+
def start_table_diff_model_progress(self, model: str) -> None:
240+
"""Start table diff model progress"""
241+
242+
@abc.abstractmethod
243+
def stop_table_diff_progress(self) -> None:
244+
"""Stop table diff progress bar"""
245+
246+
@abc.abstractmethod
247+
def show_table_diff_details(
248+
self,
249+
models_in_source: t.List[str],
250+
models_in_target: t.List[str],
251+
models_no_diff: t.List[str],
252+
models_to_diff: t.List[str],
253+
) -> None:
254+
"""Display information about which tables are identical and which are diffed"""
255+
230256
@abc.abstractmethod
231257
def show_table_diff_summary(self, table_diff: TableDiff) -> None:
232258
"""Display information about the tables being diffed and how they are being joined"""
@@ -674,6 +700,27 @@ def show_table_diff(
674700
skip_grain_check=skip_grain_check,
675701
)
676702

703+
def update_table_diff_progress(self, model: str) -> None:
704+
pass
705+
706+
def start_table_diff_progress(self, models_to_diff: int) -> None:
707+
pass
708+
709+
def start_table_diff_model_progress(self, model: str) -> None:
710+
pass
711+
712+
def stop_table_diff_progress(self) -> None:
713+
pass
714+
715+
def show_table_diff_details(
716+
self,
717+
models_in_source: t.List[str],
718+
models_in_target: t.List[str],
719+
models_no_diff: t.List[str],
720+
models_to_diff: t.List[str],
721+
) -> None:
722+
pass
723+
677724
def show_table_diff_summary(self, table_diff: TableDiff) -> None:
678725
pass
679726

@@ -773,6 +820,11 @@ def __init__(
773820
self.state_import_snapshot_task: t.Optional[TaskID] = None
774821
self.state_import_environment_task: t.Optional[TaskID] = None
775822

823+
self.table_diff_progress: t.Optional[Progress] = None
824+
self.table_diff_model_progress: t.Optional[Progress] = None
825+
self.table_diff_model_tasks: t.Dict[str, TaskID] = {}
826+
self.table_diff_progress_live: t.Optional[Live] = None
827+
776828
self.verbosity = verbosity
777829
self.dialect = dialect
778830
self.ignore_warnings = ignore_warnings
@@ -1928,6 +1980,87 @@ def loading_stop(self, id: uuid.UUID) -> None:
19281980
self.loading_status[id].stop()
19291981
del self.loading_status[id]
19301982

1983+
def show_table_diff_details(
1984+
self,
1985+
models_in_source: t.List[str],
1986+
models_in_target: t.List[str],
1987+
models_no_diff: t.List[str],
1988+
models_to_diff: t.List[str],
1989+
) -> None:
1990+
"""Display information about which tables are identical and which are diffed"""
1991+
1992+
if models_in_source:
1993+
m_tree = Tree("\n[b]Models only in source environment:")
1994+
for m in models_in_source:
1995+
m_tree.add(f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m}[/{self.TABLE_DIFF_SOURCE_BLUE}]")
1996+
self._print(m_tree)
1997+
1998+
if models_in_target:
1999+
m_tree = Tree("\n[b]Models only in target environment:")
2000+
for m in models_in_target:
2001+
m_tree.add(f"[{self.TABLE_DIFF_TARGET_GREEN}]{m}[/{self.TABLE_DIFF_TARGET_GREEN}]")
2002+
self._print(m_tree)
2003+
2004+
if models_no_diff:
2005+
m_tree = Tree("\n[b]Models without changes:")
2006+
for m in models_no_diff:
2007+
m_tree.add(f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m}[/{self.TABLE_DIFF_SOURCE_BLUE}]")
2008+
self._print(m_tree)
2009+
2010+
if models_to_diff:
2011+
m_tree = Tree("\n[b]Models to compare:")
2012+
for m in models_to_diff:
2013+
m_tree.add(f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m}[/{self.TABLE_DIFF_SOURCE_BLUE}]")
2014+
self._print(m_tree)
2015+
self._print("")
2016+
2017+
def start_table_diff_progress(self, models_to_diff: int) -> None:
2018+
if not self.table_diff_progress:
2019+
self.table_diff_progress = make_progress_bar(
2020+
"Calculating model differences", self.console
2021+
)
2022+
self.table_diff_model_progress = Progress(
2023+
TextColumn("{task.fields[view_name]}", justify="right"),
2024+
SpinnerColumn(spinner_name="simpleDots"),
2025+
console=self.console,
2026+
)
2027+
2028+
progress_table = Table.grid()
2029+
progress_table.add_row(self.table_diff_progress)
2030+
progress_table.add_row(self.table_diff_model_progress)
2031+
2032+
self.table_diff_progress_live = Live(progress_table, refresh_per_second=10)
2033+
self.table_diff_progress_live.start()
2034+
2035+
self.table_diff_model_task = self.table_diff_progress.add_task(
2036+
"Diffing", total=models_to_diff
2037+
)
2038+
2039+
def start_table_diff_model_progress(self, model: str) -> None:
2040+
if self.table_diff_model_progress and model not in self.table_diff_model_tasks:
2041+
self.table_diff_model_tasks[model] = self.table_diff_model_progress.add_task(
2042+
f"Diffing {model}...",
2043+
view_name=model,
2044+
total=1,
2045+
)
2046+
2047+
def update_table_diff_progress(self, model: str) -> None:
2048+
if self.table_diff_progress:
2049+
self.table_diff_progress.update(self.table_diff_model_task, refresh=True, advance=1)
2050+
if self.table_diff_model_progress and model in self.table_diff_model_tasks:
2051+
model_task_id = self.table_diff_model_tasks[model]
2052+
self.table_diff_model_progress.remove_task(model_task_id)
2053+
2054+
def stop_table_diff_progress(self) -> None:
2055+
if self.table_diff_progress_live:
2056+
self.table_diff_progress_live.stop()
2057+
self.table_diff_progress_live = None
2058+
self.log_status_update("")
2059+
self.log_success(f"{GREEN_CHECK_MARK} Table diff completed")
2060+
self.table_diff_progress = None
2061+
self.table_diff_model_progress = None
2062+
self.table_diff_model_tasks = {}
2063+
19312064
def show_table_diff_summary(self, table_diff: TableDiff) -> None:
19322065
tree = Tree("\n[b]Table Diff")
19332066

sqlmesh/core/context.py

+80-59
Original file line numberDiff line numberDiff line change
@@ -1571,8 +1571,7 @@ def table_diff(
15711571
target: str,
15721572
on: t.List[str] | exp.Condition | None = None,
15731573
skip_columns: t.List[str] | None = None,
1574-
model_or_snapshot: t.Optional[ModelOrSnapshot] = None,
1575-
select_model: t.Optional[t.Collection[str]] = None,
1574+
select_models: t.Optional[t.Collection[str]] = None,
15761575
where: t.Optional[str | exp.Condition] = None,
15771576
limit: int = 20,
15781577
show: bool = True,
@@ -1589,7 +1588,7 @@ def table_diff(
15891588
on: The join condition, table aliases must be "s" and "t" for source and target.
15901589
If omitted, the table's grain will be used.
15911590
skip_columns: The columns to skip when computing the table diff.
1592-
model_or_snapshot: The model or snapshot to use when environments are passed in.
1591+
select_models: The modelσ or snapshotσ to use when environments are passed in.
15931592
where: An optional where statement to filter results.
15941593
limit: The limit of the sample dataframe.
15951594
show: Show the table diff output in the console.
@@ -1605,51 +1604,81 @@ def table_diff(
16051604
table_diffs: t.List[TableDiff] = []
16061605

16071606
# Diffs multiple or a single model across two environments
1608-
if model_or_snapshot or select_model:
1607+
if select_models:
16091608
source_env = self.state_reader.get_environment(source)
16101609
target_env = self.state_reader.get_environment(target)
16111610
if not source_env:
16121611
raise SQLMeshError(f"Could not find environment '{source}'")
16131612
if not target_env:
16141613
raise SQLMeshError(f"Could not find environment '{target}'")
16151614

1616-
modified_snapshots: t.Set[ModelOrSnapshot] = (
1617-
{model_or_snapshot} if model_or_snapshot else set()
1618-
)
1619-
if select_model:
1620-
models_to_diff = self._new_selector().expand_model_selections(select_model)
1621-
target_snapshots = {
1622-
s.name: s
1623-
for s in self.state_reader.get_snapshots(target_env.snapshots).values()
1624-
if s.name in models_to_diff
1625-
}
1626-
context_diff = self._context_diff(
1627-
source,
1628-
snapshots=target_snapshots,
1629-
ensure_finalized_snapshots=self.config.plan.use_finalized_state,
1615+
selected_models = self._new_selector().expand_model_selections(select_models)
1616+
models_to_diff: t.List[t.Tuple[Model, EngineAdapter, str, str]] = []
1617+
models_in_source: t.List[str] = []
1618+
models_in_target: t.List[str] = []
1619+
models_no_diff: t.List[str] = []
1620+
1621+
for model_or_snapshot in selected_models:
1622+
model = self.get_model(model_or_snapshot, raise_if_missing=True)
1623+
adapter = self._get_engine_adapter(model.gateway)
1624+
source_snapshot = next(
1625+
(snapshot for snapshot in source_env.snapshots if snapshot.name == model.fqn),
1626+
None,
16301627
)
1631-
modified_snapshots = {
1632-
current_snapshot.snapshot_id.name
1633-
for _, (current_snapshot, _) in context_diff.modified_snapshots.items()
1634-
}
1635-
tasks_num = min(len(modified_snapshots), self.concurrent_tasks)
1636-
table_diffs = concurrent_apply_to_values(
1637-
list(modified_snapshots),
1638-
lambda s: self._model_diff(
1639-
source_env=source_env,
1640-
target_env=target_env,
1641-
model_or_snapshot=s,
1642-
limit=limit,
1643-
decimals=decimals,
1644-
on=on,
1645-
skip_columns=skip_columns,
1646-
where=where,
1647-
show=show,
1648-
temp_schema=temp_schema,
1649-
skip_grain_check=skip_grain_check,
1650-
),
1651-
tasks_num=tasks_num,
1628+
target_snapshot = next(
1629+
(snapshot for snapshot in target_env.snapshots if snapshot.name == model.fqn),
1630+
None,
1631+
)
1632+
if source_snapshot is None and target_snapshot:
1633+
models_in_source.append(model_or_snapshot)
1634+
elif target_snapshot is None and source_snapshot:
1635+
models_in_target.append(model_or_snapshot)
1636+
elif target_snapshot and source_snapshot:
1637+
if source_snapshot.fingerprint != target_snapshot.fingerprint:
1638+
# Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
1639+
# to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
1640+
source = source_snapshot.qualified_view_name.for_environment(
1641+
source_env.naming_info, adapter.dialect
1642+
)
1643+
target = target_snapshot.qualified_view_name.for_environment(
1644+
target_env.naming_info, adapter.dialect
1645+
)
1646+
1647+
models_to_diff.append((model, adapter, source, target))
1648+
else:
1649+
models_no_diff.append(model_or_snapshot)
1650+
1651+
self.console.show_table_diff_details(
1652+
models_in_source,
1653+
models_in_target,
1654+
models_no_diff,
1655+
[model[0].name for model in models_to_diff],
16521656
)
1657+
1658+
if models_to_diff:
1659+
self.console.start_table_diff_progress(len(models_to_diff))
1660+
tasks_num = min(len(models_to_diff), self.concurrent_tasks)
1661+
table_diffs = concurrent_apply_to_values(
1662+
list(models_to_diff),
1663+
lambda model_info: self._model_diff(
1664+
model=model_info[0],
1665+
adapter=model_info[1],
1666+
source=model_info[2],
1667+
target=model_info[3],
1668+
source_alias=source_env.name,
1669+
target_alias=target_env.name,
1670+
limit=limit,
1671+
decimals=decimals,
1672+
on=on,
1673+
skip_columns=skip_columns,
1674+
where=where,
1675+
show=show,
1676+
temp_schema=temp_schema,
1677+
skip_grain_check=skip_grain_check,
1678+
),
1679+
tasks_num=tasks_num,
1680+
)
1681+
self.console.stop_table_diff_progress()
16531682
else:
16541683
table_diffs = [
16551684
self._table_diff(
@@ -1673,9 +1702,12 @@ def table_diff(
16731702

16741703
def _model_diff(
16751704
self,
1676-
source_env: Environment,
1677-
target_env: Environment,
1678-
model_or_snapshot: ModelOrSnapshot,
1705+
model: Model,
1706+
adapter: EngineAdapter,
1707+
source: str,
1708+
target: str,
1709+
source_alias: str,
1710+
target_alias: str,
16791711
limit: int,
16801712
decimals: int,
16811713
on: t.Optional[t.List[str] | exp.Condition] = None,
@@ -1685,22 +1717,6 @@ def _model_diff(
16851717
temp_schema: t.Optional[str] = None,
16861718
skip_grain_check: bool = False,
16871719
) -> TableDiff:
1688-
model = self.get_model(model_or_snapshot, raise_if_missing=True)
1689-
adapter = self._get_engine_adapter(model.gateway)
1690-
1691-
# Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
1692-
# to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
1693-
source = next(
1694-
snapshot for snapshot in source_env.snapshots if snapshot.name == model.fqn
1695-
).qualified_view_name.for_environment(source_env.naming_info, adapter.dialect)
1696-
1697-
target = next(
1698-
snapshot for snapshot in target_env.snapshots if snapshot.name == model.fqn
1699-
).qualified_view_name.for_environment(target_env.naming_info, adapter.dialect)
1700-
1701-
source_alias = source_env.name
1702-
target_alias = target_env.name
1703-
17041720
if not on:
17051721
on = []
17061722
for expr in [ref.expression for ref in model.all_references if ref.unique]:
@@ -1710,6 +1726,8 @@ def _model_diff(
17101726
# Handle a single Column or Paren expression
17111727
on.append(expr.this.sql(dialect=adapter.dialect))
17121728

1729+
self.console.start_table_diff_model_progress(model.name)
1730+
17131731
table_diff = self._table_diff(
17141732
on=on,
17151733
skip_columns=skip_columns,
@@ -1723,10 +1741,13 @@ def _model_diff(
17231741
source_alias=source_alias,
17241742
target_alias=target_alias,
17251743
)
1726-
# Trigger row_diff in parallel execution so it's available for ordered display later
1744+
17271745
if show:
1746+
# Trigger row_diff in parallel execution so it's available for ordered display later
17281747
table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check)
17291748

1749+
self.console.update_table_diff_progress(model.name)
1750+
17301751
return table_diff
17311752

17321753
def _table_diff(

0 commit comments

Comments
 (0)