Skip to content

Commit

Permalink
feat: added OL support for AzureBlobStorageToGCSOperator in google pr…
Browse files Browse the repository at this point in the history
…ovider package (#40290)

* added get_openlineage_facets_on_start function in operator definition

* updated account name fetch from connection and added test

* added account_name test in azure for blob

* removed gcs hook and execute() from test

* used pre-commit
  • Loading branch information
rahul-madaan authored Jun 20, 2024
1 parent 3b6fba9 commit 19fc7d2
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
13 changes: 13 additions & 0 deletions airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,16 @@ def execute(self, context: Context) -> str:
self.bucket_name,
)
return f"gs://{self.bucket_name}/{self.object_name}"

def get_openlineage_facets_on_start(self):
    """Describe this transfer's lineage for OpenLineage before the task runs.

    The single input dataset is the source blob, with the namespace
    ``wasbs://<container_name>@<account_name>``; the account name is read
    from the client object returned by ``WasbHook.get_conn()``. The single
    output dataset is the destination object, with the namespace
    ``gs://<bucket_name>``.

    :return: an ``OperatorLineage`` carrying one input and one output dataset.
    """
    # Imports stay function-local, as in the original; presumably this keeps
    # OpenLineage an optional dependency of the module -- confirm with the provider.
    from openlineage.client.run import Dataset

    from airflow.providers.openlineage.extractors import OperatorLineage

    storage_account = WasbHook(wasb_conn_id=self.wasb_conn_id).get_conn().account_name

    source_blob = Dataset(
        namespace=f"wasbs://{self.container_name}@{storage_account}",
        name=self.blob_name,
    )
    destination_object = Dataset(
        namespace=f"gs://{self.bucket_name}",
        name=self.object_name,
    )

    return OperatorLineage(inputs=[source_blob], outputs=[destination_object])
29 changes: 29 additions & 0 deletions tests/providers/google/cloud/transfers/test_azure_blob_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,32 @@ def test_execute(self, mock_temp, mock_hook_gcs, mock_hook_wasb):
gzip=GZIP,
filename=mock_temp.NamedTemporaryFile.return_value.__enter__.return_value.name,
)

@mock.patch("airflow.providers.google.cloud.transfers.azure_blob_to_gcs.WasbHook")
def test_execute_single_file_transfer_openlineage(self, mock_hook_wasb):
    """Check the lineage reported for a single blob-to-object transfer.

    ``WasbHook`` is patched so that the client returned by ``get_conn()``
    exposes a fixed account name; the test then verifies that the operator
    reports exactly one input dataset (the Azure blob) and one output
    dataset (the GCS object) with the expected namespaces and names.
    """
    from openlineage.client.run import Dataset

    fake_account_name = "mock_account_name"
    patched_client = mock_hook_wasb.return_value.get_conn.return_value
    patched_client.account_name = fake_account_name

    operator_kwargs = {
        "wasb_conn_id": WASB_CONN_ID,
        "gcp_conn_id": GCP_CONN_ID,
        "blob_name": BLOB_NAME,
        "container_name": CONTAINER_NAME,
        "bucket_name": BUCKET_NAME,
        "object_name": OBJECT_NAME,
        "filename": FILENAME,
        "gzip": GZIP,
        "impersonation_chain": IMPERSONATION_CHAIN,
        "task_id": TASK_ID,
    }
    transfer = AzureBlobStorageToGCSOperator(**operator_kwargs)

    facets = transfer.get_openlineage_facets_on_start()

    # Exactly one dataset on each side of the transfer.
    assert len(facets.inputs) == 1
    assert len(facets.outputs) == 1

    expected_source = Dataset(
        namespace=f"wasbs://{CONTAINER_NAME}@{fake_account_name}",
        name=BLOB_NAME,
    )
    assert facets.inputs[0] == expected_source

    expected_destination = Dataset(namespace=f"gs://{BUCKET_NAME}", name=OBJECT_NAME)
    assert facets.outputs[0] == expected_destination
30 changes: 30 additions & 0 deletions tests/providers/microsoft/azure/hooks/test_wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,33 @@ def test_connection_failure(self, mocked_blob_service_client):
status, msg = hook.test_connection()
assert status is False
assert msg == "Authentication failed."

@pytest.mark.parametrize(
    "conn_id_str",
    [
        "wasb_test_key",
        "pub_read_id",
        "pub_read_id_without_host",
        "azure_test_connection_string",
        "azure_shared_key_test",
        "ad_conn_id",
        "managed_identity_conn_id",
        "sas_conn_id",
        "extra__wasb__sas_conn_id",
        "http_sas_conn_id",
        "extra__wasb__http_sas_conn_id",
    ],
)
def test_extract_account_name_from_connection(self, conn_id_str, mocked_blob_service_client):
    """Check that ``get_conn().account_name`` is readable for each connection id.

    The account name is planted on the mocked client object: for the
    connection-string connection it is set on the object produced by
    ``from_connection_string()``, and for every other connection id on the
    mock's ``return_value``.
    """
    expected_account_name = "testname"

    # Pick the mocked client that matches the connection id under test.
    if conn_id_str == "azure_test_connection_string":
        client_under_test = mocked_blob_service_client.from_connection_string()
    else:
        client_under_test = mocked_blob_service_client.return_value
    client_under_test.account_name = expected_account_name

    account_name = WasbHook(wasb_conn_id=conn_id_str).get_conn().account_name

    assert (
        account_name == expected_account_name
    ), f"Expected account name {expected_account_name} but got {account_name}"

0 comments on commit 19fc7d2

Please sign in to comment.