From 19fc7d2e7336d31d2c7d01097096fb50f14b52f5 Mon Sep 17 00:00:00 2001
From: Rahul Madan <34760210+rahul-madaan@users.noreply.github.com>
Date: Thu, 20 Jun 2024 17:26:23 +0530
Subject: [PATCH] feat: added OL support for AzureBlobStorageToGCSOperator in google provider package (#40290)

* added get_openlineage_facets_on_start function in operator definition

* updated account name fetch from connection and added test

* added account_name test in azure for blob

* removed gcs hook and execute() from test

* used pre-commit
---
 .../cloud/transfers/azure_blob_to_gcs.py      | 13 ++++++++
 .../cloud/transfers/test_azure_blob_to_gcs.py | 29 ++++++++++++++++++
 .../microsoft/azure/hooks/test_wasb.py        | 30 +++++++++++++++++++
 3 files changed, 72 insertions(+)

diff --git a/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
index 1da9e82c09247..683bfcdda8f6d 100644
--- a/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
@@ -122,3 +122,16 @@ def execute(self, context: Context) -> str:
                 self.bucket_name,
             )
         return f"gs://{self.bucket_name}/{self.object_name}"
+
+    def get_openlineage_facets_on_start(self):
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        wasb_hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
+        account_name = wasb_hook.get_conn().account_name
+
+        return OperatorLineage(
+            inputs=[Dataset(namespace=f"wasbs://{self.container_name}@{account_name}", name=self.blob_name)],
+            outputs=[Dataset(namespace=f"gs://{self.bucket_name}", name=self.object_name)],
+        )
diff --git a/tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py b/tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py
index a0b3eae99d683..b71d747ebf7d0 100644
--- a/tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py
@@ -91,3 +91,32 @@ def test_execute(self, mock_temp, mock_hook_gcs, mock_hook_wasb):
             gzip=GZIP,
             filename=mock_temp.NamedTemporaryFile.return_value.__enter__.return_value.name,
         )
+
+    @mock.patch("airflow.providers.google.cloud.transfers.azure_blob_to_gcs.WasbHook")
+    def test_execute_single_file_transfer_openlineage(self, mock_hook_wasb):
+        from openlineage.client.run import Dataset
+
+        MOCK_AZURE_ACCOUNT_NAME = "mock_account_name"
+        mock_hook_wasb.return_value.get_conn.return_value.account_name = MOCK_AZURE_ACCOUNT_NAME
+
+        operator = AzureBlobStorageToGCSOperator(
+            wasb_conn_id=WASB_CONN_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            blob_name=BLOB_NAME,
+            container_name=CONTAINER_NAME,
+            bucket_name=BUCKET_NAME,
+            object_name=OBJECT_NAME,
+            filename=FILENAME,
+            gzip=GZIP,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            task_id=TASK_ID,
+        )
+
+        lineage = operator.get_openlineage_facets_on_start()
+
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs[0] == Dataset(
+            namespace=f"wasbs://{CONTAINER_NAME}@{MOCK_AZURE_ACCOUNT_NAME}", name=BLOB_NAME
+        )
+        assert lineage.outputs[0] == Dataset(namespace=f"gs://{BUCKET_NAME}", name=OBJECT_NAME)
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 51b4c94a9213e..39c4b7505250d 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -615,3 +615,33 @@ def test_connection_failure(self, mocked_blob_service_client):
         status, msg = hook.test_connection()
         assert status is False
         assert msg == "Authentication failed."
+
+    @pytest.mark.parametrize(
+        "conn_id_str",
+        [
+            "wasb_test_key",
+            "pub_read_id",
+            "pub_read_id_without_host",
+            "azure_test_connection_string",
+            "azure_shared_key_test",
+            "ad_conn_id",
+            "managed_identity_conn_id",
+            "sas_conn_id",
+            "extra__wasb__sas_conn_id",
+            "http_sas_conn_id",
+            "extra__wasb__http_sas_conn_id",
+        ],
+    )
+    def test_extract_account_name_from_connection(self, conn_id_str, mocked_blob_service_client):
+        expected_account_name = "testname"
+        if conn_id_str == "azure_test_connection_string":
+            mocked_blob_service_client.from_connection_string().account_name = expected_account_name
+        else:
+            mocked_blob_service_client.return_value.account_name = expected_account_name
+
+        wasb_hook = WasbHook(wasb_conn_id=conn_id_str)
+        account_name = wasb_hook.get_conn().account_name
+
+        assert (
+            account_name == expected_account_name
+        ), f"Expected account name {expected_account_name} but got {account_name}"