Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

add --state flag feature to dbt integration #600

Merged
merged 6 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,14 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
"-s",
default=None,
metavar="PATH",
help="select dbt resources to compare using dbt selection syntax",
help="select dbt resources to compare using dbt selection syntax.",
)
@click.option(
"--state",
"-s",
default=None,
metavar="PATH",
help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.",
)
def main(conf, run, **kw):
if kw["table2"] is None and kw["database2"]:
Expand Down Expand Up @@ -267,6 +274,9 @@ def main(conf, run, **kw):
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT)

try:
state = kw.pop("state", None)
if state:
state = os.path.expanduser(state)
profiles_dir_override = kw.pop("dbt_profiles_dir", None)
if profiles_dir_override:
profiles_dir_override = os.path.expanduser(profiles_dir_override)
Expand All @@ -279,11 +289,12 @@ def main(conf, run, **kw):
project_dir_override=project_dir_override,
is_cloud=kw["cloud"],
dbt_selection=kw["select"],
state=state,
)
else:
return _data_diff(dbt_project_dir=project_dir_override,
dbt_profiles_dir=profiles_dir_override,
**kw)
return _data_diff(
dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw
)
except Exception as e:
logging.error(e)
if kw["debug"]:
Expand Down Expand Up @@ -324,6 +335,7 @@ def _data_diff(
dbt_profiles_dir,
dbt_project_dir,
select,
state,
threads1=None,
threads2=None,
__conf__=None,
Expand Down
6 changes: 3 additions & 3 deletions data_diff/cloud/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def _validate_temp_schema(temp_schema: str):


def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
diff_vars = dbt_parser.get_datadiff_variables()
config_prod_database = diff_vars.get("prod_database")
config_prod_schema = diff_vars.get("prod_schema")
config = dbt_parser.get_datadiff_config()
config_prod_database = config.prod_database
config_prod_schema = config.prod_schema
if config_prod_database is not None and config_prod_schema is not None:
temp_schema = f"{config_prod_database}.{config_prod_schema}"
if db_type == "snowflake":
Expand Down
119 changes: 72 additions & 47 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
import re
import time
import webbrowser
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple, Union
import keyring

import pydantic
import rich
from rich.prompt import Confirm, Prompt

from data_diff.errors import DataDiffCustomSchemaNoConfigError, DataDiffDbtProjectVarsNotFoundError

from . import connect_to_table, diff_tables, Algorithm
from .cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta, get_or_create_data_source
from .dbt_parser import DbtParser, PROJECT_FILE
from .dbt_parser import DbtParser, PROJECT_FILE, TDatadiffConfig
from .tracking import (
bool_ask_for_email,
create_email_signup_event_json,
Expand Down Expand Up @@ -55,22 +56,21 @@ def dbt_diff(
project_dir_override: Optional[str] = None,
is_cloud: bool = False,
dbt_selection: Optional[str] = None,
state: Optional[str] = None,
) -> None:
print_version_info()
diff_threads = []
set_entrypoint_name("CLI-dbt")
dbt_parser = DbtParser(profiles_dir_override, project_dir_override)
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state)
models = dbt_parser.get_models(dbt_selection)
datadiff_variables = dbt_parser.get_datadiff_variables()
config_prod_database = datadiff_variables.get("prod_database")
config_prod_schema = datadiff_variables.get("prod_schema")
config_prod_custom_schema = datadiff_variables.get("prod_custom_schema")
datasource_id = datadiff_variables.get("datasource_id")
config = dbt_parser.get_datadiff_config()
_initialize_events(dbt_parser.dbt_user_id, dbt_parser.dbt_version, dbt_parser.dbt_project_id)

if datadiff_variables.get("custom_schemas") is not None:
logger.warning(
"vars: data_diff: custom_schemas: is no longer used and can be removed.\nTo utilize custom schemas, see the documentation here: https://docs.datafold.com/development_testing/open_source"

if not state and not (config.prod_database or config.prod_schema):
doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project"
raise DataDiffDbtProjectVarsNotFoundError(
f"""vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n\nOr specify a production manifest using the `--state` flag."""
)

if is_cloud:
Expand All @@ -80,13 +80,13 @@ def dbt_diff(
return
org_meta = api.get_org_meta()

if datasource_id is None:
if config.datasource_id is None:
rich.print("[red]Data source ID not found in dbt_project.yml")
is_create_data_source = Confirm.ask("Would you like to create a new data source?")
if is_create_data_source:
datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser)
config.datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser)
rich.print(f'To use the data source in next runs, please, update your "{PROJECT_FILE}" with a block:')
rich.print(f"[green]vars:\n data_diff:\n datasource_id: {datasource_id}\n")
rich.print(f"[green]vars:\n data_diff:\n datasource_id: {config.datasource_id}\n")
rich.print(
"Read more about Datafold vars in docs: "
"https://docs.datafold.com/os_diff/dbt_integration/#configure-a-data-source\n"
Expand All @@ -97,21 +97,29 @@ def dbt_diff(
"\nvars:\n data_diff:\n datasource_id: 1234"
)

data_source = api.get_data_source(datasource_id)
data_source = api.get_data_source(config.datasource_id)
dbt_parser.set_casing_policy_for(connection_type=data_source.type)
rich.print("[green][bold]\nDiffs in progress...[/][/]\n")

else:
dbt_parser.set_connection()

for model in models:
diff_vars = _get_diff_vars(
dbt_parser, config_prod_database, config_prod_schema, config_prod_custom_schema, model
)
diff_vars = _get_diff_vars(dbt_parser, config, model)

# we won't always have a prod path when using state
# when the model DNE in prod manifest, skip the model diff
if (
state and len(diff_vars.prod_path) < 2
): # < 2 because some providers like databricks can legitimately have *only* 2
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
diff_output_str += "[green]New model: nothing to diff![/] \n"
rich.print(diff_output_str)
continue

if diff_vars.primary_keys:
if is_cloud:
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api, org_meta)
diff_thread = run_as_daemon(_cloud_diff, diff_vars, config.datasource_id, api, org_meta)
diff_threads.append(diff_thread)
else:
_local_diff(diff_vars)
Expand All @@ -129,41 +137,19 @@ def dbt_diff(

def _get_diff_vars(
dbt_parser: "DbtParser",
config_prod_database: Optional[str],
config_prod_schema: Optional[str],
config_prod_custom_schema: Optional[str],
config: TDatadiffConfig,
model,
) -> TDiffVars:
dev_database = model.database
dev_schema = model.schema_

primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")

# "custom" dbt config database
if model.config.database:
prod_database = model.config.database
elif config_prod_database:
prod_database = config_prod_database
# prod path is constructed via configuration or the prod manifest via --state
if dbt_parser.prod_manifest_obj:
prod_database, prod_schema = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj)
else:
prod_database = dev_database

# prod schema name differs from dev schema name
if config_prod_schema:
custom_schema = model.config.schema_

# the model has a custom schema config(schema='some_schema')
if custom_schema:
if not config_prod_custom_schema:
raise ValueError(
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value!\n"
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
)
prod_schema = config_prod_custom_schema.replace("<custom_schema>", custom_schema)
# no custom schema, use the default
else:
prod_schema = config_prod_schema
else:
prod_schema = dev_schema
prod_database, prod_schema = _get_prod_path_from_config(config, model, dev_database, dev_schema)

if dbt_parser.requires_upper:
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias] if x]
Expand All @@ -187,6 +173,45 @@ def _get_diff_vars(
)


def _get_prod_path_from_config(config, model, dev_database, dev_schema) -> Tuple[str, str]:
# "custom" dbt config database
if model.config.database:
prod_database = model.config.database
elif config.prod_database:
prod_database = config.prod_database
else:
prod_database = dev_database

# prod schema name differs from dev schema name
if config.prod_schema:
custom_schema = model.config.schema_

# the model has a custom schema config(schema='some_schema')
if custom_schema:
if not config.prod_custom_schema:
raise DataDiffCustomSchemaNoConfigError(
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value or utilize the `--state` flag!\n\n"
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
)
prod_schema = config.prod_custom_schema.replace("<custom_schema>", custom_schema)
# no custom schema, use the default
else:
prod_schema = config.prod_schema
else:
prod_schema = dev_schema
return prod_database, prod_schema


def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str], Tuple[None, None]]:
prod_database = None
prod_schema = None
prod_model = prod_manifest.nodes.get(model.unique_id, None)
if prod_model:
prod_database = prod_model.database
prod_schema = prod_model.schema_
return prod_database, prod_schema


def _local_diff(diff_vars: TDiffVars) -> None:
dev_qualified_str = ".".join(diff_vars.dev_path)
prod_qualified_str = ".".join(diff_vars.prod_path)
Expand Down
66 changes: 46 additions & 20 deletions data_diff/dbt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
DataDiffDbtCoreNoRunnerError,
DataDiffDbtNoSuccessfulModelsInRunError,
DataDiffDbtProfileNotFoundError,
DataDiffDbtProjectVarsNotFoundError,
DataDiffDbtRedshiftPasswordOnlyError,
DataDiffDbtRunResultsVersionError,
DataDiffDbtSelectNoMatchingModelsError,
Expand Down Expand Up @@ -88,29 +87,52 @@ class TDatadiffModelConfig(pydantic.BaseModel):
exclude_columns: List[str] = []


class TDatadiffConfig(pydantic.BaseModel):
    """Project-level data-diff settings; every field is optional and defaults to ``None``."""

    # Database and schema to use for the production side of a diff.
    prod_database: Optional[str] = None
    prod_schema: Optional[str] = None
    # Schema to use for models that declare a custom schema.
    # NOTE(review): presumably a template containing "<custom_schema>" — confirm against callers.
    prod_custom_schema: Optional[str] = None
    # Identifier of the data source to diff against.
    datasource_id: Optional[int] = None


class DbtParser:
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
def __init__(
self,
profiles_dir_override: Optional[str] = None,
project_dir_override: Optional[str] = None,
state: Optional[str] = None,
) -> None:
try_set_dbt_flags()
self.dbt_runner = try_get_dbt_runner()
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
self.project_dir = Path(project_dir_override or default_project_dir())
self.connection = {}
self.project_dict = self.get_project_dict()
self.manifest_obj = self.get_manifest_obj()
self.dbt_user_id = self.manifest_obj.metadata.user_id
self.dbt_version = self.manifest_obj.metadata.dbt_version
self.dbt_project_id = self.manifest_obj.metadata.project_id
self.dev_manifest_obj = self.get_manifest_obj(self.project_dir / MANIFEST_PATH)
self.prod_manifest_obj = None
if state:
self.prod_manifest_obj = self.get_manifest_obj(Path(state))

self.dbt_user_id = self.dev_manifest_obj.metadata.user_id
self.dbt_version = self.dev_manifest_obj.metadata.dbt_version
self.dbt_project_id = self.dev_manifest_obj.metadata.project_id
self.requires_upper = False
self.threads = None
self.unique_columns = self.get_unique_columns()

def get_datadiff_variables(self) -> dict:
doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project"
exception = DataDiffDbtProjectVarsNotFoundError(
f"vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n"
def get_datadiff_config(self) -> TDatadiffConfig:
data_diff_vars = self.project_dict.get("vars", {}).get("data_diff", {})
prod_database = data_diff_vars.get("prod_database")
prod_schema = data_diff_vars.get("prod_schema")
prod_custom_schema = data_diff_vars.get("prod_custom_schema")
datasource_id = data_diff_vars.get("datasource_id")
config = TDatadiffConfig(
prod_database=prod_database,
prod_schema=prod_schema,
prod_custom_schema=prod_custom_schema,
datasource_id=datasource_id,
)
vars_dict = get_from_dict_with_raise(self.project_dict, "vars", exception)
return get_from_dict_with_raise(vars_dict, "data_diff", exception)
logger.info(f"config: {config}")
return config

def get_datadiff_model_config(self, model_meta: dict) -> TDatadiffModelConfig:
where_filter = None
Expand Down Expand Up @@ -172,7 +194,7 @@ def get_dbt_selection_models(self, dbt_selection: str) -> List[str]:

if results.success and results.result:
model_list = [json.loads(model)["unique_id"] for model in results.result]
models = [self.manifest_obj.nodes.get(x) for x in model_list]
models = [self.dev_manifest_obj.nodes.get(x) for x in model_list]
return models

if not results.result:
Expand Down Expand Up @@ -202,15 +224,17 @@ def get_run_results_models(self):
)

success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
models = [self.manifest_obj.nodes.get(x) for x in success_models]
models = [self.dev_manifest_obj.nodes.get(x) for x in success_models]
if not models:
raise DataDiffDbtNoSuccessfulModelsInRunError("Expected > 0 successful models runs from the last dbt command.")
raise DataDiffDbtNoSuccessfulModelsInRunError(
"Expected > 0 successful models runs from the last dbt command."
)

return models

def get_manifest_obj(self):
with open(self.project_dir / MANIFEST_PATH) as manifest:
logger.info(f"Parsing file {MANIFEST_PATH}")
def get_manifest_obj(self, path: Path):
with open(path) as manifest:
logger.info(f"Parsing file {path}")
manifest_dict = json.load(manifest)
manifest_obj = parse_manifest(manifest=manifest_dict)
return manifest_obj
Expand Down Expand Up @@ -315,7 +339,9 @@ def set_connection(self):
if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get(
"method"
) == "iam":
raise DataDiffDbtRedshiftPasswordOnlyError("Only password authentication is currently supported for Redshift.")
raise DataDiffDbtRedshiftPasswordOnlyError(
"Only password authentication is currently supported for Redshift."
)
conn_info = {
"driver": conn_type,
"host": credentials.get("host"),
Expand Down Expand Up @@ -386,7 +412,7 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str
return []

def get_unique_columns(self) -> Dict[str, Set[str]]:
manifest = self.manifest_obj
manifest = self.dev_manifest_obj
cols_by_uid = defaultdict(set)
for node in manifest.nodes.values():
try:
Expand Down
4 changes: 4 additions & 0 deletions data_diff/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ class DataDiffDbtCoreNoRunnerError(Exception):

class DataDiffDbtSelectVersionTooLowError(Exception):
    "Raised when attempting to use `--select` with a dbt-core version < 1.5."


class DataDiffCustomSchemaNoConfigError(Exception):
    "Raised when a model has a custom schema, but there is no prod_custom_schema config. (And not using --state)."
Empty file added tests/cloud/__init__.py
Empty file.
Loading