1
1
import os
2
2
3
3
from pathlib import Path
4
-
5
4
from data_diff .cloud .datafold_api import TCloudApiDataSource
5
+ from data_diff .cloud .datafold_api import TCloudApiOrgMeta
6
6
from data_diff .diff_tables import Algorithm
7
7
from .test_cli import run_datadiff_cli
8
8
@@ -569,6 +569,7 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
569
569
@patch ("data_diff.dbt.os.environ" )
570
570
@patch ("data_diff.dbt.DatafoldAPI" )
571
571
def test_cloud_diff (self , mock_api , mock_os_environ , mock_print ):
572
+ org_meta = TCloudApiOrgMeta (org_id = 1 , org_name = "" , user_id = 1 )
572
573
expected_api_key = "an_api_key"
573
574
dev_qualified_list = ["dev_db" , "dev_schema" , "dev_table" ]
574
575
prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
@@ -591,7 +592,7 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
591
592
exclude_columns = [],
592
593
)
593
594
594
- _cloud_diff (diff_vars , expected_datasource_id , api = mock_api )
595
+ _cloud_diff (diff_vars , expected_datasource_id , org_meta = org_meta , api = mock_api )
595
596
596
597
mock_api .create_data_diff .assert_called_once ()
597
598
self .assertEqual (mock_print .call_count , 2 )
@@ -613,8 +614,16 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
613
614
@patch ("data_diff.dbt.rich.print" )
614
615
@patch ("data_diff.dbt.DatafoldAPI" )
615
616
def test_diff_is_cloud (
616
- self , mock_api , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api ,
617
+ self ,
618
+ mock_api ,
619
+ mock_print ,
620
+ mock_dbt_parser ,
621
+ mock_cloud_diff ,
622
+ mock_local_diff ,
623
+ mock_get_diff_vars ,
624
+ mock_initialize_api ,
617
625
):
626
+ org_meta = TCloudApiOrgMeta (org_id = 1 , org_name = "" , user_id = 1 )
618
627
connection = {}
619
628
threads = None
620
629
where = "a_string"
@@ -627,6 +636,8 @@ def test_diff_is_cloud(
627
636
mock_model = Mock ()
628
637
mock_api .get_data_source .return_value = TCloudApiDataSource (id = 1 , type = "snowflake" , name = "snowflake" )
629
638
mock_initialize_api .return_value = mock_api
639
+ mock_api .get_org_meta .return_value = org_meta
640
+
630
641
mock_dbt_parser .return_value = mock_dbt_parser_inst
631
642
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
632
643
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -649,7 +660,7 @@ def test_diff_is_cloud(
649
660
650
661
mock_initialize_api .assert_called_once ()
651
662
mock_api .get_data_source .assert_called_once_with (1 )
652
- mock_cloud_diff .assert_called_once_with (diff_vars , 1 , mock_api )
663
+ mock_cloud_diff .assert_called_once_with (diff_vars , 1 , mock_api , org_meta )
653
664
mock_local_diff .assert_not_called ()
654
665
mock_print .assert_called_once ()
655
666
@@ -663,20 +674,20 @@ def test_diff_is_cloud(
663
674
def test_diff_is_cloud_no_ds_id (
664
675
self , _ , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
665
676
):
677
+ org_meta = TCloudApiOrgMeta (org_id = 1 , org_name = "" , user_id = 1 )
666
678
connection = {}
667
679
threads = None
668
680
where = "a_string"
669
- host = "a_host"
670
- api_key = "a_api_key"
671
681
mock_dbt_parser_inst = Mock ()
672
682
mock_model = Mock ()
673
683
expected_dbt_vars_dict = {
674
684
"prod_database" : "prod_db" ,
675
685
"prod_schema" : "prod_schema" ,
676
686
}
687
+ mock_api = Mock ()
688
+ mock_initialize_api .return_value = mock_api
689
+ mock_api .get_org_meta .return_value = org_meta
677
690
678
- api = DatafoldAPI (api_key = api_key , host = host )
679
- mock_initialize_api .return_value = api
680
691
mock_dbt_parser .return_value = mock_dbt_parser_inst
681
692
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
682
693
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -827,8 +838,18 @@ def test_diff_only_prod_schema(
827
838
@patch ("data_diff.dbt.rich.print" )
828
839
@patch ("data_diff.dbt.DatafoldAPI" )
829
840
def test_diff_is_cloud_no_pks (
830
- self , mock_api , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars , mock_initialize_api
841
+ self ,
842
+ mock_api ,
843
+ mock_print ,
844
+ mock_dbt_parser ,
845
+ mock_cloud_diff ,
846
+ mock_local_diff ,
847
+ mock_get_diff_vars ,
848
+ mock_initialize_api ,
831
849
):
850
+ mock_dbt_parser_inst = Mock ()
851
+ mock_dbt_parser .return_value = mock_dbt_parser_inst
852
+ mock_model = Mock ()
832
853
connection = {}
833
854
threads = None
834
855
where = "a_string"
@@ -837,11 +858,8 @@ def test_diff_is_cloud_no_pks(
837
858
"prod_schema" : "prod_schema" ,
838
859
"datasource_id" : 1 ,
839
860
}
840
- mock_dbt_parser_inst = Mock ()
841
- mock_dbt_parser .return_value = mock_dbt_parser_inst
842
- mock_model = Mock ()
861
+ mock_api = Mock ()
843
862
mock_initialize_api .return_value = mock_api
844
- mock_api .get_data_source .return_value = TCloudApiDataSource (id = 1 , type = "snowflake" , name = "snowflake" )
845
863
846
864
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
847
865
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
0 commit comments