diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 08506b2c..2df1d5c3 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -234,7 +234,14 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) - "-s", default=None, metavar="PATH", - help="select dbt resources to compare using dbt selection syntax", + help="select dbt resources to compare using dbt selection syntax.", +) +@click.option( + "--state", + "-s", + default=None, + metavar="PATH", + help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.", ) def main(conf, run, **kw): if kw["table2"] is None and kw["database2"]: @@ -267,6 +274,9 @@ def main(conf, run, **kw): logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT) try: + state = kw.pop("state", None) + if state: + state = os.path.expanduser(state) profiles_dir_override = kw.pop("dbt_profiles_dir", None) if profiles_dir_override: profiles_dir_override = os.path.expanduser(profiles_dir_override) @@ -279,11 +289,12 @@ def main(conf, run, **kw): project_dir_override=project_dir_override, is_cloud=kw["cloud"], dbt_selection=kw["select"], + state=state, ) else: - return _data_diff(dbt_project_dir=project_dir_override, - dbt_profiles_dir=profiles_dir_override, - **kw) + return _data_diff( + dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw + ) except Exception as e: logging.error(e) if kw["debug"]: @@ -324,6 +335,7 @@ def _data_diff( dbt_profiles_dir, dbt_project_dir, select, + state, threads1=None, threads2=None, __conf__=None, diff --git a/data_diff/cloud/data_source.py b/data_diff/cloud/data_source.py index 05331c01..6469c897 100644 --- a/data_diff/cloud/data_source.py +++ b/data_diff/cloud/data_source.py @@ -52,9 +52,9 @@ def _validate_temp_schema(temp_schema: str): def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]: - diff_vars = dbt_parser.get_datadiff_variables() - 
config_prod_database = diff_vars.get("prod_database") - config_prod_schema = diff_vars.get("prod_schema") + config = dbt_parser.get_datadiff_config() + config_prod_database = config.prod_database + config_prod_schema = config.prod_schema if config_prod_database is not None and config_prod_schema is not None: temp_schema = f"{config_prod_database}.{config_prod_schema}" if db_type == "snowflake": diff --git a/data_diff/dbt.py b/data_diff/dbt.py index a21786d7..ba642708 100644 --- a/data_diff/dbt.py +++ b/data_diff/dbt.py @@ -2,16 +2,17 @@ import re import time import webbrowser -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Tuple, Union import keyring - import pydantic import rich from rich.prompt import Confirm, Prompt +from data_diff.errors import DataDiffCustomSchemaNoConfigError, DataDiffDbtProjectVarsNotFoundError + from . import connect_to_table, diff_tables, Algorithm from .cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta, get_or_create_data_source -from .dbt_parser import DbtParser, PROJECT_FILE +from .dbt_parser import DbtParser, PROJECT_FILE, TDatadiffConfig from .tracking import ( bool_ask_for_email, create_email_signup_event_json, @@ -55,22 +56,21 @@ def dbt_diff( project_dir_override: Optional[str] = None, is_cloud: bool = False, dbt_selection: Optional[str] = None, + state: Optional[str] = None, ) -> None: print_version_info() diff_threads = [] set_entrypoint_name("CLI-dbt") - dbt_parser = DbtParser(profiles_dir_override, project_dir_override) + dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state) models = dbt_parser.get_models(dbt_selection) - datadiff_variables = dbt_parser.get_datadiff_variables() - config_prod_database = datadiff_variables.get("prod_database") - config_prod_schema = datadiff_variables.get("prod_schema") - config_prod_custom_schema = datadiff_variables.get("prod_custom_schema") - datasource_id = datadiff_variables.get("datasource_id") + config = 
dbt_parser.get_datadiff_config() _initialize_events(dbt_parser.dbt_user_id, dbt_parser.dbt_version, dbt_parser.dbt_project_id) - if datadiff_variables.get("custom_schemas") is not None: - logger.warning( - "vars: data_diff: custom_schemas: is no longer used and can be removed.\nTo utilize custom schemas, see the documentation here: https://docs.datafold.com/development_testing/open_source" + + if not state and not (config.prod_database or config.prod_schema): + doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project" + raise DataDiffDbtProjectVarsNotFoundError( + f"""vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n\nOr specify a production manifest using the `--state` flag.""" ) if is_cloud: @@ -80,13 +80,13 @@ def dbt_diff( return org_meta = api.get_org_meta() - if datasource_id is None: + if config.datasource_id is None: rich.print("[red]Data source ID not found in dbt_project.yml") is_create_data_source = Confirm.ask("Would you like to create a new data source?") if is_create_data_source: - datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser) + config.datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser) rich.print(f'To use the data source in next runs, please, update your "{PROJECT_FILE}" with a block:') - rich.print(f"[green]vars:\n data_diff:\n datasource_id: {datasource_id}\n") + rich.print(f"[green]vars:\n data_diff:\n datasource_id: {config.datasource_id}\n") rich.print( "Read more about Datafold vars in docs: " "https://docs.datafold.com/os_diff/dbt_integration/#configure-a-data-source\n" @@ -97,7 +97,7 @@ def dbt_diff( "\nvars:\n data_diff:\n datasource_id: 1234" ) - data_source = api.get_data_source(datasource_id) + data_source = api.get_data_source(config.datasource_id) dbt_parser.set_casing_policy_for(connection_type=data_source.type) rich.print("[green][bold]\nDiffs in progress...[/][/]\n") @@ 
-105,13 +105,21 @@ def dbt_diff( dbt_parser.set_connection() for model in models: - diff_vars = _get_diff_vars( - dbt_parser, config_prod_database, config_prod_schema, config_prod_custom_schema, model - ) + diff_vars = _get_diff_vars(dbt_parser, config, model) + + # we won't always have a prod path when using state + # when the model DNE in prod manifest, skip the model diff + if ( + state and len(diff_vars.prod_path) < 2 + ): # < 2 because some providers like databricks can legitimately have *only* 2 + diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path)) + diff_output_str += "[green]New model: nothing to diff![/] \n" + rich.print(diff_output_str) + continue if diff_vars.primary_keys: if is_cloud: - diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api, org_meta) + diff_thread = run_as_daemon(_cloud_diff, diff_vars, config.datasource_id, api, org_meta) diff_threads.append(diff_thread) else: _local_diff(diff_vars) @@ -129,9 +137,7 @@ def dbt_diff( def _get_diff_vars( dbt_parser: "DbtParser", - config_prod_database: Optional[str], - config_prod_schema: Optional[str], - config_prod_custom_schema: Optional[str], + config: TDatadiffConfig, model, ) -> TDiffVars: dev_database = model.database @@ -139,31 +145,11 @@ def _get_diff_vars( primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key") - # "custom" dbt config database - if model.config.database: - prod_database = model.config.database - elif config_prod_database: - prod_database = config_prod_database + # prod path is constructed via configuration or the prod manifest via --state + if dbt_parser.prod_manifest_obj: + prod_database, prod_schema = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj) else: - prod_database = dev_database - - # prod schema name differs from dev schema name - if config_prod_schema: - custom_schema = model.config.schema_ - - # the model has a custom schema 
config(schema='some_schema') - if custom_schema: - if not config_prod_custom_schema: - raise ValueError( - f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value!\n" - + "For more details see: https://docs.datafold.com/development_testing/open_source" - ) - prod_schema = config_prod_custom_schema.replace("<custom_schema>", custom_schema) - # no custom schema, use the default - else: - prod_schema = config_prod_schema - else: - prod_schema = dev_schema + prod_database, prod_schema = _get_prod_path_from_config(config, model, dev_database, dev_schema) if dbt_parser.requires_upper: dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias] if x] @@ -187,6 +173,45 @@ def _get_diff_vars( ) +def _get_prod_path_from_config(config, model, dev_database, dev_schema) -> Tuple[str, str]: + # "custom" dbt config database + if model.config.database: + prod_database = model.config.database + elif config.prod_database: + prod_database = config.prod_database + else: + prod_database = dev_database + + # prod schema name differs from dev schema name + if config.prod_schema: + custom_schema = model.config.schema_ + + # the model has a custom schema config(schema='some_schema') + if custom_schema: + if not config.prod_custom_schema: + raise DataDiffCustomSchemaNoConfigError( + f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value or utilize the `--state` flag!\n\n" + + "For more details see: https://docs.datafold.com/development_testing/open_source" + ) + prod_schema = config.prod_custom_schema.replace("<custom_schema>", custom_schema) + # no custom schema, use the default + else: + prod_schema = config.prod_schema + else: + prod_schema = dev_schema + return prod_database, prod_schema + + +def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str], Tuple[None, None]]: + prod_database = None + prod_schema = None + prod_model = 
prod_manifest.nodes.get(model.unique_id, None) + if prod_model: + prod_database = prod_model.database + prod_schema = prod_model.schema_ + return prod_database, prod_schema + + def _local_diff(diff_vars: TDiffVars) -> None: dev_qualified_str = ".".join(diff_vars.dev_path) prod_qualified_str = ".".join(diff_vars.prod_path) diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index 14c8caa8..b5da5df8 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -16,7 +16,6 @@ DataDiffDbtCoreNoRunnerError, DataDiffDbtNoSuccessfulModelsInRunError, DataDiffDbtProfileNotFoundError, - DataDiffDbtProjectVarsNotFoundError, DataDiffDbtRedshiftPasswordOnlyError, DataDiffDbtRunResultsVersionError, DataDiffDbtSelectNoMatchingModelsError, @@ -88,29 +87,52 @@ class TDatadiffModelConfig(pydantic.BaseModel): exclude_columns: List[str] = [] +class TDatadiffConfig(pydantic.BaseModel): + prod_database: Optional[str] = None + prod_schema: Optional[str] = None + prod_custom_schema: Optional[str] = None + datasource_id: Optional[int] = None + + class DbtParser: - def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None: + def __init__( + self, + profiles_dir_override: Optional[str] = None, + project_dir_override: Optional[str] = None, + state: Optional[str] = None, + ) -> None: try_set_dbt_flags() self.dbt_runner = try_get_dbt_runner() self.profiles_dir = Path(profiles_dir_override or default_profiles_dir()) self.project_dir = Path(project_dir_override or default_project_dir()) self.connection = {} self.project_dict = self.get_project_dict() - self.manifest_obj = self.get_manifest_obj() - self.dbt_user_id = self.manifest_obj.metadata.user_id - self.dbt_version = self.manifest_obj.metadata.dbt_version - self.dbt_project_id = self.manifest_obj.metadata.project_id + self.dev_manifest_obj = self.get_manifest_obj(self.project_dir / MANIFEST_PATH) + self.prod_manifest_obj = None + if state: + self.prod_manifest_obj = self.get_manifest_obj(Path(state)) 
+ + self.dbt_user_id = self.dev_manifest_obj.metadata.user_id + self.dbt_version = self.dev_manifest_obj.metadata.dbt_version + self.dbt_project_id = self.dev_manifest_obj.metadata.project_id self.requires_upper = False self.threads = None self.unique_columns = self.get_unique_columns() - def get_datadiff_variables(self) -> dict: - doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project" - exception = DataDiffDbtProjectVarsNotFoundError( - f"vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n" + def get_datadiff_config(self) -> TDatadiffConfig: + data_diff_vars = self.project_dict.get("vars", {}).get("data_diff", {}) + prod_database = data_diff_vars.get("prod_database") + prod_schema = data_diff_vars.get("prod_schema") + prod_custom_schema = data_diff_vars.get("prod_custom_schema") + datasource_id = data_diff_vars.get("datasource_id") + config = TDatadiffConfig( + prod_database=prod_database, + prod_schema=prod_schema, + prod_custom_schema=prod_custom_schema, + datasource_id=datasource_id, ) - vars_dict = get_from_dict_with_raise(self.project_dict, "vars", exception) - return get_from_dict_with_raise(vars_dict, "data_diff", exception) + logger.info(f"config: {config}") + return config def get_datadiff_model_config(self, model_meta: dict) -> TDatadiffModelConfig: where_filter = None @@ -172,7 +194,7 @@ def get_dbt_selection_models(self, dbt_selection: str) -> List[str]: if results.success and results.result: model_list = [json.loads(model)["unique_id"] for model in results.result] - models = [self.manifest_obj.nodes.get(x) for x in model_list] + models = [self.dev_manifest_obj.nodes.get(x) for x in model_list] return models if not results.result: @@ -202,15 +224,17 @@ def get_run_results_models(self): ) success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"] - models = [self.manifest_obj.nodes.get(x) for x in 
success_models] + models = [self.dev_manifest_obj.nodes.get(x) for x in success_models] if not models: - raise DataDiffDbtNoSuccessfulModelsInRunError("Expected > 0 successful models runs from the last dbt command.") + raise DataDiffDbtNoSuccessfulModelsInRunError( + "Expected > 0 successful models runs from the last dbt command." + ) return models - def get_manifest_obj(self): - with open(self.project_dir / MANIFEST_PATH) as manifest: - logger.info(f"Parsing file {MANIFEST_PATH}") + def get_manifest_obj(self, path: Path): + with open(path) as manifest: + logger.info(f"Parsing file {path}") manifest_dict = json.load(manifest) manifest_obj = parse_manifest(manifest=manifest_dict) return manifest_obj @@ -315,7 +339,9 @@ def set_connection(self): if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get( "method" ) == "iam": - raise DataDiffDbtRedshiftPasswordOnlyError("Only password authentication is currently supported for Redshift.") + raise DataDiffDbtRedshiftPasswordOnlyError( + "Only password authentication is currently supported for Redshift." + ) conn_info = { "driver": conn_type, "host": credentials.get("host"), @@ -386,7 +412,7 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str return [] def get_unique_columns(self) -> Dict[str, Set[str]]: - manifest = self.manifest_obj + manifest = self.dev_manifest_obj cols_by_uid = defaultdict(set) for node in manifest.nodes.values(): try: diff --git a/data_diff/errors.py b/data_diff/errors.py index 63965685..16b5757c 100644 --- a/data_diff/errors.py +++ b/data_diff/errors.py @@ -44,3 +44,7 @@ class DataDiffDbtCoreNoRunnerError(Exception): class DataDiffDbtSelectVersionTooLowError(Exception): "Raised when attempting to use `--select` with a dbt-core version < 1.5." + + +class DataDiffCustomSchemaNoConfigError(Exception): + "Raised when a model has a custom schema, but there is no prod_custom_schema config. (And not using --state)." 
diff --git a/tests/cloud/__init__.py b/tests/cloud/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cloud/test_data_source.py b/tests/cloud/test_data_source.py index 8c8cebe4..cc524a32 100644 --- a/tests/cloud/test_data_source.py +++ b/tests/cloud/test_data_source.py @@ -21,6 +21,7 @@ _get_temp_schema, _test_data_source, ) +from data_diff.dbt_parser import TDatadiffConfig DATA_SOURCE_CONFIGS = { @@ -145,12 +146,9 @@ def setUp(self) -> None: @parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test) @patch("data_diff.dbt_parser.DbtParser.__new__") def test_get_temp_schema(self, config: TDsConfig, mock_dbt_parser): - diff_vars = { - "prod_database": "db", - "prod_schema": "schema", - } - mock_dbt_parser.get_datadiff_variables.return_value = diff_vars - temp_schema = f'{diff_vars["prod_database"]}.{diff_vars["prod_schema"]}' + datadiff_config = TDatadiffConfig(prod_database="db", prod_schema="schema") + mock_dbt_parser.get_datadiff_config.return_value = datadiff_config + temp_schema = f"{datadiff_config.prod_database}.{datadiff_config.prod_schema}" if config.type == "snowflake": temp_schema = temp_schema.upper() elif config.type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}: diff --git a/tests/test_dbt.py b/tests/test_dbt.py index 89b62979..3870e1cc 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -5,6 +5,7 @@ from data_diff.cloud.datafold_api import TCloudApiOrgMeta from data_diff.diff_tables import Algorithm from data_diff.errors import ( + DataDiffCustomSchemaNoConfigError, DataDiffDbtBigQueryOauthOnlyError, DataDiffDbtConnectionNotImplementedError, DataDiffDbtCoreNoRunnerError, @@ -20,6 +21,8 @@ from data_diff.dbt import ( _get_diff_vars, + _get_prod_path_from_config, + _get_prod_path_from_manifest, dbt_diff, _local_diff, _cloud_diff, @@ -30,39 +33,31 @@ from data_diff.dbt_parser import ( RUN_RESULTS_PATH, PROJECT_FILE, + TDatadiffConfig, ) import unittest from 
unittest.mock import MagicMock, Mock, create_autospec, mock_open, patch, ANY class TestDbtParser(unittest.TestCase): - def test_get_datadiff_variables(self): - expected_dict = {"some_key": "some_value"} - full_dict = {"vars": {"data_diff": expected_dict}} + def test_get_datadiff_config(self): + project_dict = {"vars": {"data_diff": {"prod_database": "a_prod_database"}}} mock_self = Mock() - mock_self.project_dict = full_dict - returned_dict = DbtParser.get_datadiff_variables(mock_self) + mock_self.project_dict = project_dict + config = DbtParser.get_datadiff_config(mock_self) - self.assertEqual(expected_dict, returned_dict) + self.assertEqual(project_dict["vars"]["data_diff"]["prod_database"], config.prod_database) + self.assertEqual(config.prod_schema, None) - def test_get_datadiff_variables_none(self): - none_dict = None + def test_get_datadiff_config_no_config(self): + project_dict = {"key": {"key": "value"}} mock_self = Mock() - mock_self.project_dict = none_dict + mock_self.project_dict = project_dict - with self.assertRaises(DataDiffDbtProjectVarsNotFoundError): - DbtParser.get_datadiff_variables(mock_self) - - def test_get_datadiff_variables_empty(self): - empty_dict = {} - - mock_self = Mock() - mock_self.project_dict = empty_dict - - with self.assertRaises(DataDiffDbtProjectVarsNotFoundError): - DbtParser.get_datadiff_variables(mock_self) + config = DbtParser.get_datadiff_config(mock_self) + self.assertEqual(config, TDatadiffConfig()) def test_get_models(self): mock_self = Mock() @@ -130,7 +125,7 @@ def test_get_run_results_models(self, mock_open, mock_artifact_parser): mock_success_result.status.name = "success" mock_failed_result.status.name = "failed" mock_run_results.results = [mock_success_result, mock_failed_result] - mock_self.manifest_obj.nodes.get.return_value = mock_model + mock_self.dev_manifest_obj.nodes.get.return_value = mock_model models = DbtParser.get_run_results_models(mock_self) @@ -687,11 +682,7 @@ def test_diff_is_cloud( connection = 
{} threads = None where = "a_string" - expected_dbt_vars_dict = { - "prod_database": "prod_db", - "prod_schema": "prod_schema", - "datasource_id": 1, - } + config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", datasource_id=1) mock_dbt_parser_inst = Mock() mock_model = Mock() mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake") @@ -700,7 +691,7 @@ def test_diff_is_cloud( mock_dbt_parser.return_value = mock_dbt_parser_inst mock_dbt_parser_inst.get_models.return_value = [mock_model] - mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict + mock_dbt_parser_inst.get_datadiff_config.return_value = config diff_vars = TDiffVars( dev_path=["dev"], @@ -740,17 +731,14 @@ def test_diff_is_cloud_no_ds_id( where = "a_string" mock_dbt_parser_inst = Mock() mock_model = Mock() - expected_dbt_vars_dict = { - "prod_database": "prod_db", - "prod_schema": "prod_schema", - } + config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema") mock_api = Mock() mock_initialize_api.return_value = mock_api mock_api.get_org_meta.return_value = org_meta mock_dbt_parser.return_value = mock_dbt_parser_inst mock_dbt_parser_inst.get_models.return_value = [mock_model] - mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict + mock_dbt_parser_inst.get_datadiff_config.return_value = config diff_vars = TDiffVars( dev_path=["dev"], @@ -774,16 +762,35 @@ def test_diff_is_cloud_no_ds_id( mock_local_diff.assert_not_called() mock_print.assert_called_once() + @patch("data_diff.dbt._get_diff_vars") + @patch("data_diff.dbt._local_diff") + @patch("data_diff.dbt._cloud_diff") + @patch("data_diff.dbt_parser.DbtParser.__new__") + def test_diff_no_state_no_config(self, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars): + mock_dbt_parser_inst = Mock() + mock_model = Mock() + config = TDatadiffConfig() + + mock_dbt_parser.return_value = mock_dbt_parser_inst 
+ mock_dbt_parser_inst.get_models.return_value = [mock_model] + mock_dbt_parser_inst.get_datadiff_config.return_value = config + + with self.assertRaises(DataDiffDbtProjectVarsNotFoundError): + dbt_diff() + mock_dbt_parser_inst.get_models.assert_called_once() + mock_dbt_parser_inst.get_datadiff_config.assert_called_once() + + mock_get_diff_vars.assert_not_called() + mock_cloud_diff.assert_not_called() + mock_local_diff.assert_not_called() + @patch("data_diff.dbt._get_diff_vars") @patch("data_diff.dbt._local_diff") @patch("data_diff.dbt._cloud_diff") @patch("data_diff.dbt_parser.DbtParser.__new__") @patch("data_diff.dbt.rich.print") def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars): - expected_dbt_vars_dict = { - "prod_database": "prod_db", - "prod_schema": "prod_schema", - } + config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema") connection = {} threads = None where = "a_string" @@ -791,7 +798,7 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] - mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict + mock_dbt_parser_inst.get_datadiff_config.return_value = config diff_vars = TDiffVars( dev_path=["dev"], @@ -812,24 +819,60 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m mock_local_diff.assert_called_once_with(diff_vars) mock_print.assert_not_called() + @patch("data_diff.dbt._get_diff_vars") + @patch("data_diff.dbt._local_diff") + @patch("data_diff.dbt._cloud_diff") + @patch("data_diff.dbt_parser.DbtParser.__new__") + @patch("data_diff.dbt.rich.print") + def test_diff_state_model_dne( + self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars + ): + config = TDatadiffConfig(prod_database="prod_db", 
prod_schema="prod_schema") + connection = {} + threads = None + where = "a_string" + mock_dbt_parser_inst = Mock() + mock_dbt_parser.return_value = mock_dbt_parser_inst + mock_model = Mock() + mock_dbt_parser_inst.get_models.return_value = [mock_model] + mock_dbt_parser_inst.get_datadiff_config.return_value = config + mock_dbt_parser_inst.get_datadiff_config.return_value = TDatadiffConfig() + + diff_vars = TDiffVars( + dev_path=["dev_db", "dev_schema", "model"], + prod_path=["model"], + primary_keys=["pks"], + connection=connection, + threads=threads, + where_filter=where, + include_columns=[], + exclude_columns=[], + ) + mock_get_diff_vars.return_value = diff_vars + dbt_diff(is_cloud=False, state="/manifest_path.json") + + mock_dbt_parser_inst.get_models.assert_called_once() + mock_dbt_parser_inst.set_connection.assert_called_once() + mock_cloud_diff.assert_not_called() + mock_local_diff.assert_not_called() + self.assertTrue("nothing to diff" in mock_print.call_args[0][0]) + mock_print.assert_called_once() + @patch("data_diff.dbt._get_diff_vars") @patch("data_diff.dbt._local_diff") @patch("data_diff.dbt._cloud_diff") @patch("data_diff.dbt_parser.DbtParser.__new__") @patch("data_diff.dbt.rich.print") def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars): + config = TDatadiffConfig(prod_database="prod_db") connection = {} threads = None where = "a_string" - expected_dbt_vars_dict = { - "prod_database": "prod_db", - "datasource_id": 1, - } mock_dbt_parser_inst = Mock() mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] - mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict + mock_dbt_parser_inst.get_datadiff_config.return_value = config diff_vars = TDiffVars( dev_path=["dev"], @@ -858,18 +901,15 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m def 
test_diff_only_prod_schema( self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars ): + config = TDatadiffConfig(prod_schema="prod_schema") connection = {} threads = None where = "a_string" - expected_dbt_vars_dict = { - "datasource_id": 1, - "prod_schema": "prod_schema", - } mock_dbt_parser_inst = Mock() mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] - mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict + mock_dbt_parser_inst.get_datadiff_config.return_value = config diff_vars = TDiffVars( dev_path=["dev"], @@ -913,16 +953,12 @@ def test_diff_is_cloud_no_pks( connection = {} threads = None where = "a_string" - expected_dbt_vars_dict = { - "prod_database": "prod_db", - "prod_schema": "prod_schema", - "datasource_id": 1, - } + config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", datasource_id=1) mock_api = Mock() mock_initialize_api.return_value = mock_api mock_dbt_parser_inst.get_models.return_value = [mock_model] - mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict + mock_dbt_parser_inst.get_datadiff_config.return_value = config diff_vars = TDiffVars( dev_path=["dev"], prod_path=["prod"], @@ -952,19 +988,15 @@ def test_diff_is_cloud_no_pks( def test_diff_not_is_cloud_no_pks( self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars ): + config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema") connection = {} threads = None where = "a_string" - expected_dbt_vars_dict = { - "prod_database": "prod_db", - "prod_schema": "prod_schema", - "datasource_id": 1, - } mock_dbt_parser_inst = Mock() mock_dbt_parser.return_value = mock_dbt_parser_inst mock_model = Mock() mock_dbt_parser_inst.get_models.return_value = [mock_model] - mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict + 
mock_dbt_parser_inst.get_datadiff_config.return_value = config diff_vars = TDiffVars( dev_path=["dev"], @@ -984,154 +1016,116 @@ def test_diff_not_is_cloud_no_pks( mock_local_diff.assert_not_called() self.assertEqual(mock_print.call_count, 1) - def test_get_diff_vars_replace_custom_schema(self): - prod_database = "a_prod_db" - prod_schema = "a_prod_schema" - primary_keys = ["a_primary_key"] + def test_get_prod_path_from_config_replace_custom_schema(self): + config = TDatadiffConfig( + prod_database="prod_db", prod_schema="prod_schema", prod_custom_schema="prod_<custom_schema>" + ) mock_model = Mock() mock_model.database = "a_dev_db" mock_model.schema_ = "a_custom_schema" mock_model.config.schema_ = mock_model.schema_ mock_model.config.database = None mock_model.alias = "a_model_name" - mock_tdatadiffmodelconfig = Mock() - mock_tdatadiffmodelconfig.where_filter = "where" - mock_tdatadiffmodelconfig.include_columns = ["include"] - mock_tdatadiffmodelconfig.exclude_columns = ["exclude"] - mock_dbt_parser = Mock() - mock_dbt_parser.get_pk_from_model.return_value = primary_keys - mock_dbt_parser.requires_upper = False - mock_dbt_parser.get_datadiff_model_config.return_value = mock_tdatadiffmodelconfig - mock_dbt_parser.connection = {} - mock_dbt_parser.threads = 0 mock_model.meta = None - diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod_<custom_schema>", mock_model) - - self.assertEqual(diff_vars.dev_path, [mock_model.database, mock_model.schema_, mock_model.alias]) - self.assertEqual(diff_vars.prod_path, [prod_database, "prod_" + mock_model.schema_, mock_model.alias]) - self.assertEqual(diff_vars.primary_keys, primary_keys) - self.assertEqual(diff_vars.connection, mock_dbt_parser.connection) - self.assertEqual(diff_vars.threads, mock_dbt_parser.threads) - self.assertEqual(diff_vars.threads, mock_dbt_parser.threads) - self.assertNotIn(prod_schema, diff_vars.prod_path) + prod_database, prod_schema = _get_prod_path_from_config( + config, mock_model, mock_model.database, 
mock_model.schema_ + ) - mock_dbt_parser.get_pk_from_model.assert_called_once() + self.assertEqual(prod_schema, "prod_" + mock_model.schema_) + self.assertEqual(prod_database, config.prod_database) - def test_get_diff_vars_static_custom_schema(self): + def test_get_prod_path_from_config_static_custom_schema(self): + config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", prod_custom_schema="prod") mock_model = Mock() - prod_database = "a_prod_db" - prod_schema = "a_prod_schema" - primary_keys = ["a_primary_key"] mock_model.database = "a_dev_db" mock_model.schema_ = "a_custom_schema" mock_model.config.database = None mock_model.config.schema_ = mock_model.schema_ mock_model.alias = "a_model_name" - mock_tdatadiffmodelconfig = Mock() - mock_tdatadiffmodelconfig.where_filter = "where" - mock_tdatadiffmodelconfig.include_columns = ["include"] - mock_tdatadiffmodelconfig.exclude_columns = ["exclude"] - mock_dbt_parser = Mock() - mock_dbt_parser.get_pk_from_model.return_value = primary_keys - mock_dbt_parser.get_datadiff_model_config.return_value = mock_tdatadiffmodelconfig - mock_dbt_parser.connection = {} - mock_dbt_parser.threads = 0 - mock_dbt_parser.requires_upper = False mock_model.meta = None - diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model) + prod_database, prod_schema = _get_prod_path_from_config( + config, mock_model, mock_model.database, mock_model.schema_ + ) - self.assertEqual(diff_vars.dev_path, [mock_model.database, mock_model.schema_, mock_model.alias]) - self.assertEqual(diff_vars.prod_path, [prod_database, "prod", mock_model.alias]) - self.assertEqual(diff_vars.primary_keys, primary_keys) - self.assertEqual(diff_vars.connection, mock_dbt_parser.connection) - self.assertEqual(diff_vars.threads, mock_dbt_parser.threads) - self.assertNotIn(prod_schema, diff_vars.prod_path) - mock_dbt_parser.get_pk_from_model.assert_called_once() + self.assertEqual(prod_schema, config.prod_custom_schema) + 
self.assertEqual(prod_database, config.prod_database) - def test_get_diff_vars_no_custom_schema_on_model(self): + def test_get_prod_path_from_config_no_custom_schema_on_model(self): + config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema", prod_custom_schema="prod") mock_model = Mock() - prod_database = "a_prod_db" - prod_schema = "a_prod_schema" - primary_keys = ["a_primary_key"] mock_model.database = "a_dev_db" mock_model.schema_ = "a_custom_schema" mock_model.config.schema_ = None mock_model.config.database = None mock_model.alias = "a_model_name" - mock_tdatadiffmodelconfig = Mock() - mock_tdatadiffmodelconfig.where_filter = "where" - mock_tdatadiffmodelconfig.include_columns = ["include"] - mock_tdatadiffmodelconfig.exclude_columns = ["exclude"] - mock_dbt_parser = Mock() - mock_dbt_parser.get_pk_from_model.return_value = primary_keys - mock_dbt_parser.get_datadiff_model_config.return_value = mock_tdatadiffmodelconfig - mock_dbt_parser.connection = {} - mock_dbt_parser.threads = 0 - mock_dbt_parser.requires_upper = False mock_model.meta = None - diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model) + prod_database, prod_schema = _get_prod_path_from_config( + config, mock_model, mock_model.database, mock_model.schema_ + ) - self.assertEqual(diff_vars.dev_path, [mock_model.database, mock_model.schema_, mock_model.alias]) - self.assertEqual(diff_vars.prod_path, [prod_database, prod_schema, mock_model.alias]) - self.assertEqual(diff_vars.primary_keys, primary_keys) - self.assertEqual(diff_vars.connection, mock_dbt_parser.connection) - self.assertEqual(diff_vars.threads, mock_dbt_parser.threads) - mock_dbt_parser.get_pk_from_model.assert_called_once() + self.assertEqual(prod_schema, config.prod_schema) + self.assertEqual(prod_database, config.prod_database) - def test_get_diff_vars_match_dev_schema(self): + def test_get_prod_path_from_config_match_dev_schema(self): + config = 
TDatadiffConfig(prod_database="prod_db") mock_model = Mock() - prod_database = "a_prod_db" - primary_keys = ["a_primary_key"] mock_model.database = "a_dev_db" mock_model.schema_ = "a_schema" mock_model.config.schema_ = None mock_model.config.database = None mock_model.alias = "a_model_name" - mock_tdatadiffmodelconfig = Mock() - mock_tdatadiffmodelconfig.where_filter = "where" - mock_tdatadiffmodelconfig.include_columns = ["include"] - mock_tdatadiffmodelconfig.exclude_columns = ["exclude"] - mock_dbt_parser = Mock() - mock_dbt_parser.get_pk_from_model.return_value = primary_keys - mock_dbt_parser.get_datadiff_model_config.return_value = mock_tdatadiffmodelconfig - mock_dbt_parser.connection = {} - mock_dbt_parser.threads = 0 - mock_dbt_parser.requires_upper = False mock_model.meta = None - diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model) + prod_database, prod_schema = _get_prod_path_from_config( + config, mock_model, mock_model.database, mock_model.schema_ + ) - self.assertEqual(diff_vars.dev_path, [mock_model.database, mock_model.schema_, mock_model.alias]) - self.assertEqual(diff_vars.prod_path, [prod_database, mock_model.schema_, mock_model.alias]) - self.assertEqual(diff_vars.primary_keys, primary_keys) - self.assertEqual(diff_vars.connection, mock_dbt_parser.connection) - self.assertEqual(diff_vars.threads, mock_dbt_parser.threads) - mock_dbt_parser.get_pk_from_model.assert_called_once() + self.assertEqual(prod_schema, mock_model.schema_) + self.assertEqual(prod_database, config.prod_database) + + def test_get_prod_path_from_manifest_model_exists(self): + mock_model = Mock() + mock_model.unique_id = "unique_model_id" + mock_prod_manifest = Mock() + mock_prod_model = Mock() + mock_prod_manifest.nodes.get.return_value = mock_prod_model + mock_prod_model.database = "prod_db" + mock_prod_model.schema_ = "prod_schema" + prod_database, prod_schema = _get_prod_path_from_manifest(mock_model, mock_prod_manifest) + 
self.assertEqual(prod_database, mock_prod_model.database) + self.assertEqual(prod_schema, mock_prod_model.schema_) + + def test_get_prod_path_from_manifest_model_not_exists(self): + mock_model = Mock() + mock_model.unique_id = "unique_model_id" + mock_prod_manifest = Mock() + mock_prod_model = Mock() + mock_prod_manifest.nodes.get.return_value = None + mock_prod_model.database = "prod_db" + mock_prod_model.schema_ = "prod_schema" + prod_database, prod_schema = _get_prod_path_from_manifest(mock_model, mock_prod_manifest) + self.assertEqual(prod_database, None) + self.assertEqual(prod_schema, None) def test_get_diff_custom_schema_no_config_exception(self): + config = TDatadiffConfig(prod_database="prod_db", prod_schema="prod_schema") mock_model = Mock() - prod_database = "a_prod_db" - prod_schema = "a_prod_schema" - primary_keys = ["a_primary_key"] mock_model.database = "a_dev_db" mock_model.schema_ = "a_schema" mock_model.config.schema_ = "a_custom_schema" mock_model.alias = "a_model_name" - mock_dbt_parser = Mock() - mock_dbt_parser.get_pk_from_model.return_value = primary_keys - mock_dbt_parser.requires_upper = False - with self.assertRaises(ValueError): - _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, None, mock_model) + with self.assertRaises(DataDiffCustomSchemaNoConfigError): + _get_prod_path_from_config(config, mock_model, mock_model.database, mock_model.schema_) - mock_dbt_parser.get_pk_from_model.assert_called_once() - - def test_get_diff_vars_meta_where(self): + @patch("data_diff.dbt._get_prod_path_from_config") + @patch("data_diff.dbt._get_prod_path_from_manifest") + def test_get_diff_vars_meta_where(self, mock_prod_path_from_manifest, mock_prod_path_from_config): + config = TDatadiffConfig(prod_database="prod_db") mock_model = Mock() - prod_database = "a_prod_db" primary_keys = ["a_primary_key"] mock_model.database = "a_dev_db" mock_model.schema_ = "a_schema" @@ -1148,20 +1142,26 @@ def test_get_diff_vars_meta_where(self): 
mock_dbt_parser.threads = 0 mock_dbt_parser.get_pk_from_model.return_value = primary_keys mock_dbt_parser.requires_upper = False + mock_dbt_parser.prod_manifest_obj = None + mock_prod_path_from_config.return_value = ("prod_db", "prod_schema") - diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model) + diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model) - self.assertEqual(diff_vars.dev_path, [mock_model.database, mock_model.schema_, mock_model.alias]) - self.assertEqual(diff_vars.prod_path, [prod_database, mock_model.schema_, mock_model.alias]) self.assertEqual(diff_vars.primary_keys, primary_keys) self.assertEqual(diff_vars.connection, mock_dbt_parser.connection) self.assertEqual(diff_vars.threads, mock_dbt_parser.threads) self.assertEqual(diff_vars.where_filter, mock_tdatadiffmodelconfig.where_filter) mock_dbt_parser.get_pk_from_model.assert_called_once() - - def test_get_diff_vars_meta_unrelated(self): + mock_prod_path_from_config.assert_called_once_with(config, mock_model, mock_model.database, mock_model.schema_) + mock_prod_path_from_manifest.assert_not_called() + self.assertEqual(diff_vars.prod_path[0], mock_prod_path_from_config.return_value[0]) + self.assertEqual(diff_vars.prod_path[1], mock_prod_path_from_config.return_value[1]) + + @patch("data_diff.dbt._get_prod_path_from_config") + @patch("data_diff.dbt._get_prod_path_from_manifest") + def test_get_diff_vars_meta_unrelated(self, mock_prod_path_from_manifest, mock_prod_path_from_config): + config = TDatadiffConfig(prod_database="prod_db") mock_model = Mock() - prod_database = "a_prod_db" primary_keys = ["a_primary_key"] mock_model.database = "a_dev_db" mock_model.schema_ = "a_schema" @@ -1178,20 +1178,26 @@ def test_get_diff_vars_meta_unrelated(self): mock_dbt_parser.threads = 0 mock_dbt_parser.get_pk_from_model.return_value = primary_keys mock_dbt_parser.requires_upper = False + mock_dbt_parser.prod_manifest_obj = None + mock_prod_path_from_config.return_value = 
("prod_db", "prod_schema") - diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model) + diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model) - self.assertEqual(diff_vars.dev_path, [mock_model.database, mock_model.schema_, mock_model.alias]) - self.assertEqual(diff_vars.prod_path, [prod_database, mock_model.schema_, mock_model.alias]) self.assertEqual(diff_vars.primary_keys, primary_keys) self.assertEqual(diff_vars.connection, mock_dbt_parser.connection) self.assertEqual(diff_vars.threads, mock_dbt_parser.threads) self.assertEqual(diff_vars.where_filter, mock_tdatadiffmodelconfig.where_filter) mock_dbt_parser.get_pk_from_model.assert_called_once() - - def test_get_diff_vars_meta_none(self): + mock_prod_path_from_config.assert_called_once_with(config, mock_model, mock_model.database, mock_model.schema_) + mock_prod_path_from_manifest.assert_not_called() + self.assertEqual(diff_vars.prod_path[0], mock_prod_path_from_config.return_value[0]) + self.assertEqual(diff_vars.prod_path[1], mock_prod_path_from_config.return_value[1]) + + @patch("data_diff.dbt._get_prod_path_from_config") + @patch("data_diff.dbt._get_prod_path_from_manifest") + def test_get_diff_vars_meta_none(self, mock_prod_path_from_manifest, mock_prod_path_from_config): + config = TDatadiffConfig(prod_database="prod_db") mock_model = Mock() - prod_database = "a_prod_db" primary_keys = ["a_primary_key"] mock_model.database = "a_dev_db" mock_model.schema_ = "a_schema" @@ -1209,20 +1215,26 @@ def test_get_diff_vars_meta_none(self): mock_dbt_parser.get_pk_from_model.return_value = primary_keys mock_dbt_parser.requires_upper = False mock_model.meta = None + mock_dbt_parser.prod_manifest_obj = None + mock_prod_path_from_config.return_value = ("prod_db", "prod_schema") - diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model) + diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model) - assert diff_vars.dev_path == [mock_model.database, 
mock_model.schema_, mock_model.alias] - assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias] assert diff_vars.primary_keys == primary_keys assert diff_vars.connection == mock_dbt_parser.connection assert diff_vars.threads == mock_dbt_parser.threads self.assertEqual(diff_vars.where_filter, mock_tdatadiffmodelconfig.where_filter) mock_dbt_parser.get_pk_from_model.assert_called_once() - - def test_get_diff_vars_custom_db(self): + mock_prod_path_from_config.assert_called_once_with(config, mock_model, mock_model.database, mock_model.schema_) + mock_prod_path_from_manifest.assert_not_called() + self.assertEqual(diff_vars.prod_path[0], mock_prod_path_from_config.return_value[0]) + self.assertEqual(diff_vars.prod_path[1], mock_prod_path_from_config.return_value[1]) + + @patch("data_diff.dbt._get_prod_path_from_config") + @patch("data_diff.dbt._get_prod_path_from_manifest") + def test_get_diff_vars_custom_db(self, mock_prod_path_from_manifest, mock_prod_path_from_config): + config = TDatadiffConfig(prod_database="prod_db") mock_model = Mock() - prod_database = "a_prod_db" primary_keys = ["a_primary_key"] mock_model.database = "a_dev_db" mock_model.schema_ = "a_schema" @@ -1240,13 +1252,88 @@ def test_get_diff_vars_custom_db(self): mock_dbt_parser.get_pk_from_model.return_value = primary_keys mock_dbt_parser.requires_upper = False mock_model.meta = None + mock_dbt_parser.prod_manifest_obj = None + mock_prod_path_from_config.return_value = ("prod_db", "prod_schema") - diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model) + diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model) - assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias] - assert diff_vars.prod_path == [mock_model.config.database, mock_model.schema_, mock_model.alias] assert diff_vars.primary_keys == primary_keys assert diff_vars.connection == mock_dbt_parser.connection assert diff_vars.threads == 
mock_dbt_parser.threads self.assertEqual(diff_vars.where_filter, mock_tdatadiffmodelconfig.where_filter) mock_dbt_parser.get_pk_from_model.assert_called_once() + mock_prod_path_from_config.assert_called_once_with(config, mock_model, mock_model.database, mock_model.schema_) + mock_prod_path_from_manifest.assert_not_called() + self.assertEqual(diff_vars.prod_path[0], mock_prod_path_from_config.return_value[0]) + self.assertEqual(diff_vars.prod_path[1], mock_prod_path_from_config.return_value[1]) + + @patch("data_diff.dbt._get_prod_path_from_config") + @patch("data_diff.dbt._get_prod_path_from_manifest") + def test_get_diff_vars_upper(self, mock_prod_path_from_manifest, mock_prod_path_from_config): + config = TDatadiffConfig(prod_database="prod_db") + mock_model = Mock() + primary_keys = ["a_primary_key"] + upper_primary_keys = [x.upper() for x in primary_keys] + mock_model.database = "a_dev_db" + mock_model.schema_ = "a_schema" + mock_model.config.schema_ = None + mock_model.config.database = "custom_database" + mock_model.alias = "a_model_name" + mock_tdatadiffmodelconfig = Mock() + mock_tdatadiffmodelconfig.where_filter = "where" + mock_tdatadiffmodelconfig.include_columns = ["include"] + mock_tdatadiffmodelconfig.exclude_columns = ["exclude"] + mock_dbt_parser = Mock() + mock_dbt_parser.get_datadiff_model_config.return_value = mock_tdatadiffmodelconfig + mock_dbt_parser.connection = {} + mock_dbt_parser.threads = 0 + mock_dbt_parser.get_pk_from_model.return_value = primary_keys + mock_dbt_parser.requires_upper = True + mock_model.meta = None + mock_dbt_parser.prod_manifest_obj = None + mock_prod_path_from_config.return_value = ("prod_db", "prod_schema") + + diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model) + + self.assertEqual(diff_vars.primary_keys, upper_primary_keys) + self.assertEqual(diff_vars.connection, mock_dbt_parser.connection) + self.assertEqual(diff_vars.threads, mock_dbt_parser.threads) + self.assertEqual(diff_vars.where_filter, 
mock_tdatadiffmodelconfig.where_filter) + mock_dbt_parser.get_pk_from_model.assert_called_once() + mock_prod_path_from_config.assert_called_once_with(config, mock_model, mock_model.database, mock_model.schema_) + mock_prod_path_from_manifest.assert_not_called() + self.assertEqual(diff_vars.prod_path[0], mock_prod_path_from_config.return_value[0].upper()) + self.assertEqual(diff_vars.prod_path[1], mock_prod_path_from_config.return_value[1].upper()) + + @patch("data_diff.dbt._get_prod_path_from_config") + @patch("data_diff.dbt._get_prod_path_from_manifest") + def test_get_diff_vars_call_get_prod_path_from_manifest( + self, mock_prod_path_from_manifest, mock_prod_path_from_config + ): + config = TDatadiffConfig(prod_database="prod_db") + mock_model = Mock() + primary_keys = ["a_primary_key"] + mock_model.database = "a_dev_db" + mock_model.schema_ = "a_schema" + mock_model.config.schema_ = None + mock_model.config.database = "custom_database" + mock_model.alias = "a_model_name" + mock_tdatadiffmodelconfig = Mock() + mock_tdatadiffmodelconfig.where_filter = "where" + mock_tdatadiffmodelconfig.include_columns = ["include"] + mock_tdatadiffmodelconfig.exclude_columns = ["exclude"] + mock_dbt_parser = Mock() + mock_dbt_parser.get_datadiff_model_config.return_value = mock_tdatadiffmodelconfig + mock_dbt_parser.connection = {} + mock_dbt_parser.threads = 0 + mock_dbt_parser.get_pk_from_model.return_value = primary_keys + mock_dbt_parser.requires_upper = False + mock_model.meta = None + mock_dbt_parser.prod_manifest_obj = {"manifest_key": "manifest_value"} + mock_prod_path_from_manifest.return_value = ("prod_db", "prod_schema") + + diff_vars = _get_diff_vars(mock_dbt_parser, config, mock_model) + + mock_prod_path_from_manifest.assert_called_once_with(mock_model, mock_dbt_parser.prod_manifest_obj) + self.assertEqual(diff_vars.prod_path[0], mock_prod_path_from_manifest.return_value[0]) + self.assertEqual(diff_vars.prod_path[1], mock_prod_path_from_manifest.return_value[1])