Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 1cf89ac

Browse files
authored
Merge pull request #547 from dlawin/daniel-dx-666-improve-dbt-cloud-event-metadata
enhance cloud event metadata
2 parents 2d5db0f + ad1ad26 commit 1cf89ac

File tree

5 files changed

+58
-17
lines changed

5 files changed

+58
-17
lines changed

data_diff/cloud/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .datafold_api import DatafoldAPI, TCloudApiDataDiff
1+
from .datafold_api import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta
22
from .data_source import get_or_create_data_source

data_diff/cloud/datafold_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ class TCloudApiDataDiff(pydantic.BaseModel):
107107
exclude_columns: Optional[List[str]]
108108

109109

110+
class TCloudApiOrgMeta(pydantic.BaseModel):
    # Organization/user metadata model. Within this change it is built from
    # the JSON body of the `api/v1/organization/meta` endpoint and its three
    # fields are forwarded into the end-of-run tracking event.
    #
    # Field order is kept as declared: pydantic preserves declaration order
    # when serializing the model.
    org_id: int  # taken from the "org_id" key of the response
    org_name: str  # taken from the "org_name" key of the response
    user_id: int  # taken from the "user_id" key of the response
114+
115+
110116
class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):
111117
total_rows: Tuple[int, int]
112118
nulls: Tuple[int, int]
@@ -276,3 +282,10 @@ def check_data_source_test_results(self, job_id: int) -> List[TCloudApiDataSourc
276282
)
277283
for item in rv.json()["results"]
278284
]
285+
286+
def get_org_meta(self) -> TCloudApiOrgMeta:
    """Fetch organization metadata for the current API credentials.

    Issues a GET request to ``api/v1/organization/meta`` through
    ``self.make_get_request`` and maps the JSON body onto a
    ``TCloudApiOrgMeta`` instance.

    Returns:
        A ``TCloudApiOrgMeta`` populated from the ``org_id``, ``org_name``
        and ``user_id`` keys of the response.

    Raises:
        KeyError: if the decoded body lacks one of the three expected keys
            (assumes the body decodes to a JSON object -- confirm against
            the API contract).
    """
    # Plain string literal: the path has no interpolated parts, so the
    # original f-prefix was redundant (pyflakes/ruff F541).
    response = self.make_get_request("api/v1/organization/meta")
    response_json = response.json()
    return TCloudApiOrgMeta(
        org_id=response_json["org_id"],
        org_name=response_json["org_name"],
        user_id=response_json["user_id"],
    )

data_diff/dbt.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import keyring
1919

20-
from .cloud import DatafoldAPI, TCloudApiDataDiff, get_or_create_data_source
20+
from .cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta, get_or_create_data_source
2121
from .dbt_parser import DbtParser, PROJECT_FILE
2222

2323

@@ -78,6 +78,7 @@ def dbt_diff(
7878
# exit so the user can set the key
7979
if not api:
8080
return
81+
org_meta = api.get_org_meta()
8182

8283
if datasource_id is None:
8384
rich.print("[red]Data source ID not found in dbt_project.yml")
@@ -110,7 +111,7 @@ def dbt_diff(
110111

111112
if diff_vars.primary_keys:
112113
if is_cloud:
113-
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api)
114+
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api, org_meta)
114115
diff_threads.append(diff_thread)
115116
else:
116117
_local_diff(diff_vars)
@@ -268,7 +269,7 @@ def _initialize_api() -> Optional[DatafoldAPI]:
268269
return DatafoldAPI(api_key=api_key, host=datafold_host)
269270

270271

271-
def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI) -> None:
272+
def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_meta: TCloudApiOrgMeta) -> None:
272273
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
273274
payload = TCloudApiDataDiff(
274275
data_source1_id=datasource_id,
@@ -356,6 +357,9 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI) -> N
356357
error=err_message,
357358
diff_id=diff_id,
358359
is_cloud=True,
360+
org_id=org_meta.org_id,
361+
org_name=org_meta.org_name,
362+
user_id=org_meta.user_id,
359363
)
360364
send_event_json(event_json)
361365

data_diff/tracking.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def create_end_event_json(
116116
error: Optional[str],
117117
diff_id: Optional[int] = None,
118118
is_cloud: bool = False,
119+
org_id: Optional[int] = None,
120+
org_name: Optional[str] = None,
121+
user_id: Optional[int] = None,
119122
):
120123
return {
121124
"event": "os_diff_run_end",
@@ -138,6 +141,9 @@ def create_end_event_json(
138141
"dbt_user_id": dbt_user_id,
139142
"dbt_version": dbt_version,
140143
"dbt_project_id": dbt_project_id,
144+
"org_id": org_id,
145+
"org_name": org_name,
146+
"user_id": user_id,
141147
},
142148
}
143149

tests/test_dbt.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
22

33
from pathlib import Path
4-
54
from data_diff.cloud.datafold_api import TCloudApiDataSource
5+
from data_diff.cloud.datafold_api import TCloudApiOrgMeta
66
from data_diff.diff_tables import Algorithm
77
from .test_cli import run_datadiff_cli
88

@@ -569,6 +569,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
569569
@patch("data_diff.dbt.os.environ")
570570
@patch("data_diff.dbt.DatafoldAPI")
571571
def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
572+
org_meta = TCloudApiOrgMeta(org_id=1, org_name="", user_id=1)
572573
expected_api_key = "an_api_key"
573574
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
574575
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
@@ -591,7 +592,7 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
591592
exclude_columns=[],
592593
)
593594

594-
_cloud_diff(diff_vars, expected_datasource_id, api=mock_api)
595+
_cloud_diff(diff_vars, expected_datasource_id, org_meta=org_meta, api=mock_api)
595596

596597
mock_api.create_data_diff.assert_called_once()
597598
self.assertEqual(mock_print.call_count, 2)
@@ -613,8 +614,16 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
613614
@patch("data_diff.dbt.rich.print")
614615
@patch("data_diff.dbt.DatafoldAPI")
615616
def test_diff_is_cloud(
616-
self, mock_api, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api,
617+
self,
618+
mock_api,
619+
mock_print,
620+
mock_dbt_parser,
621+
mock_cloud_diff,
622+
mock_local_diff,
623+
mock_get_diff_vars,
624+
mock_initialize_api,
617625
):
626+
org_meta = TCloudApiOrgMeta(org_id=1, org_name="", user_id=1)
618627
connection = {}
619628
threads = None
620629
where = "a_string"
@@ -627,6 +636,8 @@ def test_diff_is_cloud(
627636
mock_model = Mock()
628637
mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")
629638
mock_initialize_api.return_value = mock_api
639+
mock_api.get_org_meta.return_value = org_meta
640+
630641
mock_dbt_parser.return_value = mock_dbt_parser_inst
631642
mock_dbt_parser_inst.get_models.return_value = [mock_model]
632643
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
@@ -649,7 +660,7 @@ def test_diff_is_cloud(
649660

650661
mock_initialize_api.assert_called_once()
651662
mock_api.get_data_source.assert_called_once_with(1)
652-
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api)
663+
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api, org_meta)
653664
mock_local_diff.assert_not_called()
654665
mock_print.assert_called_once()
655666

@@ -663,20 +674,20 @@ def test_diff_is_cloud(
663674
def test_diff_is_cloud_no_ds_id(
664675
self, _, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
665676
):
677+
org_meta = TCloudApiOrgMeta(org_id=1, org_name="", user_id=1)
666678
connection = {}
667679
threads = None
668680
where = "a_string"
669-
host = "a_host"
670-
api_key = "a_api_key"
671681
mock_dbt_parser_inst = Mock()
672682
mock_model = Mock()
673683
expected_dbt_vars_dict = {
674684
"prod_database": "prod_db",
675685
"prod_schema": "prod_schema",
676686
}
687+
mock_api = Mock()
688+
mock_initialize_api.return_value = mock_api
689+
mock_api.get_org_meta.return_value = org_meta
677690

678-
api = DatafoldAPI(api_key=api_key, host=host)
679-
mock_initialize_api.return_value = api
680691
mock_dbt_parser.return_value = mock_dbt_parser_inst
681692
mock_dbt_parser_inst.get_models.return_value = [mock_model]
682693
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
@@ -827,8 +838,18 @@ def test_diff_only_prod_schema(
827838
@patch("data_diff.dbt.rich.print")
828839
@patch("data_diff.dbt.DatafoldAPI")
829840
def test_diff_is_cloud_no_pks(
830-
self, mock_api, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
841+
self,
842+
mock_api,
843+
mock_print,
844+
mock_dbt_parser,
845+
mock_cloud_diff,
846+
mock_local_diff,
847+
mock_get_diff_vars,
848+
mock_initialize_api,
831849
):
850+
mock_dbt_parser_inst = Mock()
851+
mock_dbt_parser.return_value = mock_dbt_parser_inst
852+
mock_model = Mock()
832853
connection = {}
833854
threads = None
834855
where = "a_string"
@@ -837,11 +858,8 @@ def test_diff_is_cloud_no_pks(
837858
"prod_schema": "prod_schema",
838859
"datasource_id": 1,
839860
}
840-
mock_dbt_parser_inst = Mock()
841-
mock_dbt_parser.return_value = mock_dbt_parser_inst
842-
mock_model = Mock()
861+
mock_api = Mock()
843862
mock_initialize_api.return_value = mock_api
844-
mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")
845863

846864
mock_dbt_parser_inst.get_models.return_value = [mock_model]
847865
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict

0 commit comments

Comments
 (0)