Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

enhance cloud event metadata #547

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data_diff/cloud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .datafold_api import DatafoldAPI, TCloudApiDataDiff
from .datafold_api import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta
from .data_source import get_or_create_data_source
13 changes: 13 additions & 0 deletions data_diff/cloud/datafold_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class TCloudApiDataDiff(pydantic.BaseModel):
exclude_columns: Optional[List[str]]


class TCloudApiOrgMeta(pydantic.BaseModel):
    """Organization/user metadata built from the ``api/v1/organization/meta`` response.

    ``DatafoldAPI.get_org_meta`` fills each field from the same-named key of
    the response JSON; ``_cloud_diff`` then forwards the three values to the
    end-of-run tracking event.
    """

    # Taken from the "org_id" key of the response (presumably the organization's numeric id).
    org_id: int
    # Taken from the "org_name" key of the response.
    org_name: str
    # Taken from the "user_id" key of the response (presumably the id of the API key's user — confirm).
    user_id: int


class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):
total_rows: Tuple[int, int]
nulls: Tuple[int, int]
Expand Down Expand Up @@ -276,3 +282,10 @@ def check_data_source_test_results(self, job_id: int) -> List[TCloudApiDataSourc
)
for item in rv.json()["results"]
]

def get_org_meta(self) -> TCloudApiOrgMeta:
    """Fetch organization metadata and wrap it in a ``TCloudApiOrgMeta``.

    Sends a GET to ``api/v1/organization/meta`` via ``make_get_request`` and
    copies the ``org_id``, ``org_name`` and ``user_id`` keys of the JSON body
    into the model.

    Returns:
        TCloudApiOrgMeta: the three values taken from the response.

    Raises:
        KeyError: if the decoded JSON object lacks one of the three keys
            (assumes the body decodes to a dict — confirm against the API).
    """
    # Plain string literal: the path has no placeholders, so the f-prefix was noise (F541).
    response = self.make_get_request("api/v1/organization/meta")
    response_json = response.json()
    # Keys are picked explicitly so unrelated fields in the response are ignored.
    return TCloudApiOrgMeta(
        org_id=response_json["org_id"],
        org_name=response_json["org_name"],
        user_id=response_json["user_id"],
    )
10 changes: 7 additions & 3 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import keyring

from .cloud import DatafoldAPI, TCloudApiDataDiff, get_or_create_data_source
from .cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta, get_or_create_data_source
from .dbt_parser import DbtParser, PROJECT_FILE


Expand Down Expand Up @@ -78,6 +78,7 @@ def dbt_diff(
# exit so the user can set the key
if not api:
return
org_meta = api.get_org_meta()

if datasource_id is None:
rich.print("[red]Data source ID not found in dbt_project.yml")
Expand Down Expand Up @@ -110,7 +111,7 @@ def dbt_diff(

if diff_vars.primary_keys:
if is_cloud:
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api)
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api, org_meta)
diff_threads.append(diff_thread)
else:
_local_diff(diff_vars)
Expand Down Expand Up @@ -268,7 +269,7 @@ def _initialize_api() -> Optional[DatafoldAPI]:
return DatafoldAPI(api_key=api_key, host=datafold_host)


def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI) -> None:
def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_meta: TCloudApiOrgMeta) -> None:
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
payload = TCloudApiDataDiff(
data_source1_id=datasource_id,
Expand Down Expand Up @@ -356,6 +357,9 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI) -> N
error=err_message,
diff_id=diff_id,
is_cloud=True,
org_id=org_meta.org_id,
org_name=org_meta.org_name,
user_id=org_meta.user_id,
)
send_event_json(event_json)

Expand Down
6 changes: 6 additions & 0 deletions data_diff/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def create_end_event_json(
error: Optional[str],
diff_id: Optional[int] = None,
is_cloud: bool = False,
org_id: Optional[int] = None,
org_name: Optional[str] = None,
user_id: Optional[int] = None,
):
return {
"event": "os_diff_run_end",
Expand All @@ -138,6 +141,9 @@ def create_end_event_json(
"dbt_user_id": dbt_user_id,
"dbt_version": dbt_version,
"dbt_project_id": dbt_project_id,
"org_id": org_id,
"org_name": org_name,
"user_id": user_id,
},
}

Expand Down
44 changes: 31 additions & 13 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os

from pathlib import Path

from data_diff.cloud.datafold_api import TCloudApiDataSource
from data_diff.cloud.datafold_api import TCloudApiOrgMeta
from data_diff.diff_tables import Algorithm
from .test_cli import run_datadiff_cli

Expand Down Expand Up @@ -569,6 +569,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
@patch("data_diff.dbt.os.environ")
@patch("data_diff.dbt.DatafoldAPI")
def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
org_meta = TCloudApiOrgMeta(org_id=1, org_name="", user_id=1)
expected_api_key = "an_api_key"
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
Expand All @@ -591,7 +592,7 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
exclude_columns=[],
)

_cloud_diff(diff_vars, expected_datasource_id, api=mock_api)
_cloud_diff(diff_vars, expected_datasource_id, org_meta=org_meta, api=mock_api)

mock_api.create_data_diff.assert_called_once()
self.assertEqual(mock_print.call_count, 2)
Expand All @@ -613,8 +614,16 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
@patch("data_diff.dbt.rich.print")
@patch("data_diff.dbt.DatafoldAPI")
def test_diff_is_cloud(
self, mock_api, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api,
self,
mock_api,
mock_print,
mock_dbt_parser,
mock_cloud_diff,
mock_local_diff,
mock_get_diff_vars,
mock_initialize_api,
):
org_meta = TCloudApiOrgMeta(org_id=1, org_name="", user_id=1)
connection = {}
threads = None
where = "a_string"
Expand All @@ -627,6 +636,8 @@ def test_diff_is_cloud(
mock_model = Mock()
mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")
mock_initialize_api.return_value = mock_api
mock_api.get_org_meta.return_value = org_meta

mock_dbt_parser.return_value = mock_dbt_parser_inst
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
Expand All @@ -649,7 +660,7 @@ def test_diff_is_cloud(

mock_initialize_api.assert_called_once()
mock_api.get_data_source.assert_called_once_with(1)
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api)
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api, org_meta)
mock_local_diff.assert_not_called()
mock_print.assert_called_once()

Expand All @@ -663,20 +674,20 @@ def test_diff_is_cloud(
def test_diff_is_cloud_no_ds_id(
self, _, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
):
org_meta = TCloudApiOrgMeta(org_id=1, org_name="", user_id=1)
connection = {}
threads = None
where = "a_string"
host = "a_host"
api_key = "a_api_key"
mock_dbt_parser_inst = Mock()
mock_model = Mock()
expected_dbt_vars_dict = {
"prod_database": "prod_db",
"prod_schema": "prod_schema",
}
mock_api = Mock()
mock_initialize_api.return_value = mock_api
mock_api.get_org_meta.return_value = org_meta

api = DatafoldAPI(api_key=api_key, host=host)
mock_initialize_api.return_value = api
mock_dbt_parser.return_value = mock_dbt_parser_inst
mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
Expand Down Expand Up @@ -827,8 +838,18 @@ def test_diff_only_prod_schema(
@patch("data_diff.dbt.rich.print")
@patch("data_diff.dbt.DatafoldAPI")
def test_diff_is_cloud_no_pks(
self, mock_api, mock_print, mock_dbt_parser, mock_cloud_diff, mock_local_diff, mock_get_diff_vars, mock_initialize_api
self,
mock_api,
mock_print,
mock_dbt_parser,
mock_cloud_diff,
mock_local_diff,
mock_get_diff_vars,
mock_initialize_api,
):
mock_dbt_parser_inst = Mock()
mock_dbt_parser.return_value = mock_dbt_parser_inst
mock_model = Mock()
connection = {}
threads = None
where = "a_string"
Expand All @@ -837,11 +858,8 @@ def test_diff_is_cloud_no_pks(
"prod_schema": "prod_schema",
"datasource_id": 1,
}
mock_dbt_parser_inst = Mock()
mock_dbt_parser.return_value = mock_dbt_parser_inst
mock_model = Mock()
mock_api = Mock()
mock_initialize_api.return_value = mock_api
mock_api.get_data_source.return_value = TCloudApiDataSource(id=1, type="snowflake", name="snowflake")

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
Expand Down