diff --git a/gcp_airflow_foundations/source_class/generic_file_source.py b/gcp_airflow_foundations/source_class/generic_file_source.py
index 14813e12..a9cfcb64 100644
--- a/gcp_airflow_foundations/source_class/generic_file_source.py
+++ b/gcp_airflow_foundations/source_class/generic_file_source.py
@@ -3,6 +3,7 @@
 from datetime import datetime
 import json
 from dacite import from_dict
+from os.path import join
 
 from airflow.models.dag import DAG
 from airflow.operators.python_operator import PythonOperator
@@ -172,7 +173,7 @@ def get_list_of_files(self, table_config, **kwargs):
         # support replacing files with current dates
         file_list[:] = [file.replace("{{ ds }}", ds) if "{{ ds }}" in file else file for file in file_list]
         # add dir prefix to files
-        file_list[:] = [gcs_bucket_prefix + file for file in file_list]
+        file_list[:] = [join(gcs_bucket_prefix, file) for file in file_list]
         logging.info(file_list)
 
         kwargs['ti'].xcom_push(key='file_list', value=file_list)
diff --git a/tests/unit/sources/gcs/__init__.py b/tests/unit/sources/gcs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/sources/gcs/gcs.yaml b/tests/unit/sources/gcs/gcs.yaml
new file mode 100644
index 00000000..3dbee057
--- /dev/null
+++ b/tests/unit/sources/gcs/gcs.yaml
@@ -0,0 +1,35 @@
+source:
+  name: SampleGCS
+  source_type: GCS
+  ingest_schedule: "@daily"
+  start_date: "2022-01-03"
+  catchup: False
+  acceptable_delay_minutes: 5
+  extra_options:
+    gcs_bucket: public-gcp-airflow-foundation-samples
+    file_source_config:
+      file_name_template: "{{ TABLE_NAME }}.csv"
+      source_format: CSV
+      delimeter: ","
+      file_prefix_filtering: False
+      delete_gcs_files: False
+      sensor_timeout: 6000
+      airflow_date_template: "ds"
+      date_format: "%Y-%m-%d"
+      gcs_bucket_prefix: ""
+  gcp_project: airflow-framework
+  location: us
+  dataset_data_name: af_test_ods
+#  dataset_hds_override: af_test_hds
+  owner: test_user
+  notification_emails: []
+  landing_zone_options:
+    landing_zone_dataset: af_test_hds_landing_zone
+tables:
+  - table_name: users
+    ingestion_type: FULL
+    surrogate_keys: []
+    extra_options:
+      file_table_config:
+        directory_prefix: ""
+        allow_quoted_newlines: True
diff --git a/tests/unit/sources/gcs/test_gcs.py b/tests/unit/sources/gcs/test_gcs.py
new file mode 100644
index 00000000..0a8d28c6
--- /dev/null
+++ b/tests/unit/sources/gcs/test_gcs.py
@@ -0,0 +1,86 @@
+import os
+import unittest
+
+from gcp_airflow_foundations.base_class.utils import load_tables_config_from_dir
+from gcp_airflow_foundations.source_class import gcs_source
+from datetime import datetime
+from gcp_airflow_foundations.base_class.file_source_config import FileSourceConfig
+from gcp_airflow_foundations.base_class.file_table_config import FileTableConfig
+from dacite import from_dict
+from airflow.models.dag import DAG
+from airflow.models import DagRun, DagTag, TaskInstance, DagModel
+from airflow.utils.session import create_session
+from airflow.utils.task_group import TaskGroup
+
+DEFAULT_DATE = datetime(2015, 1, 1)
+TEST_DAG_ID = "unit_test_dag"
+DEV_NULL = "/dev/null"
+
+
+def clear_db_dags():
+    with create_session() as session:
+        session.query(DagTag).delete()
+        session.query(DagModel).delete()
+        session.query(DagRun).delete()
+        session.query(TaskInstance).delete()
+
+
+class TestGcs(unittest.TestCase):
+    def setUp(self):
+        clear_db_dags()
+
+        here = os.path.abspath(os.path.dirname(__file__))
+        self.conf_location = here
+        self.config = next(iter(load_tables_config_from_dir(self.conf_location)), None)
+
+        self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+
+    def test_validate_yaml_good(self):
+        # Validate options
+        assert self.config.source.source_type == "GCS"
+        assert self.config.source.name is not None
+
+        # Validate extra options
+        file_source_config = from_dict(data_class=FileSourceConfig, data=self.config.source.extra_options["file_source_config"])
+        assert file_source_config.airflow_date_template == "ds"
+        assert file_source_config.date_format == "%Y-%m-%d"
+        assert file_source_config.delete_gcs_files is not True
+        assert file_source_config.gcs_bucket_prefix == ""
+
+        tables = self.config.tables
+        for table_config in tables:
+            assert table_config.table_name is not None
+            file_table_config = from_dict(data_class=FileTableConfig, data=table_config.extra_options.get("file_table_config"))
+            assert isinstance(file_table_config.directory_prefix, str)
+
+    def test_gcs_file_sensor_good(self):
+        gcs_dag_builder = gcs_source.GCSFileIngestionDagBuilder(
+            default_task_args=self.args, config=self.config
+        )
+
+        for table_config in self.config.tables:
+            with DAG(
+                TEST_DAG_ID, default_args=self.args, schedule_interval="@once"
+            ) as dag:
+                task_group = TaskGroup("test", dag=dag)
+                file_sensor = gcs_dag_builder.file_sensor(table_config, task_group)
+                assert file_sensor is not None
+                assert file_sensor.bucket == "public-gcp-airflow-foundation-samples"
+                assert file_sensor.objects == "{{ ti.xcom_pull(key='file_list', task_ids='ftp_taskgroup.get_file_list') }}"
+
+    def test_gcs_list_files_good(self):
+        gcs_dag_builder = gcs_source.GCSFileIngestionDagBuilder(
+            default_task_args=self.args, config=self.config
+        )
+
+        for table_config in self.config.tables:
+            with DAG(
+                TEST_DAG_ID, default_args=self.args, schedule_interval="@once"
+            ) as dag:
+                task_group = TaskGroup("test", dag=dag)
+                file_sensor = gcs_dag_builder.file_sensor(table_config, task_group)
+                task_instance = TaskInstance(file_sensor, execution_date=datetime.now())
+                task_instance.xcom_push(key='file_list', value='1')
+                gcs_dag_builder.get_list_of_files(table_config, ds="2022-04-19", ti=task_instance)
+                file_list = task_instance.xcom_pull(key='file_list')
+                assert file_list == ["users.csv"]
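
A note on the `os.path.join` change in `get_list_of_files` (illustrative only, not part of the diff): a minimal sketch of the behavioral difference between string concatenation and `join`. The `landing/daily` prefix below is a hypothetical value for `gcs_bucket_prefix`; `os.path.join` uses the platform separator, which on POSIX matches the `/`-delimited object names GCS expects.

```python
from os.path import join

files = ["users.csv"]

# Plain concatenation silently mangles the object path when the
# configured prefix lacks a trailing slash.
prefix = "landing/daily"  # hypothetical non-empty gcs_bucket_prefix
print([prefix + f for f in files])       # ['landing/dailyusers.csv']

# os.path.join inserts the separator only when one is needed.
print([join(prefix, f) for f in files])  # ['landing/daily/users.csv']

# With the empty prefix from gcs.yaml, the name passes through unchanged,
# which is what test_gcs_list_files_good asserts (['users.csv']).
print([join("", f) for f in files])      # ['users.csv']
```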