
Commit

Merge pull request #82 from badal-io/hotfix/gcs-file-sources
fix gcs inputs
sharvenm authored May 2, 2022
2 parents 9f6a516 + 71c59a6 commit 7b8cb2b
Showing 10 changed files with 163 additions and 30 deletions.
16 changes: 13 additions & 3 deletions README.md
@@ -30,21 +30,31 @@ pip install 'gcp-airflow-foundations'
See the [gcp-airflow-foundations documentation](https://gcp-airflow-foundations.readthedocs.io/en/latest/) for more details.

## Running locally

### Sample DAGs
Sample DAGs that ingest publicly available GCS files can be found in the dags folder and are started as soon as Airflow is run locally. For them to run successfully, ensure the following:
- Enable the BigQuery, Cloud Storage, Cloud DLP, and Data Catalog APIs
- Create BigQuery datasets for the HDS and ODS
- Create an inspect template in Cloud DLP
- Create a policy tag in Data Catalog
- Update the gcp_project, location, dataset, DLP, and policy tag values in the sample configs with your newly created resources (see the sketch below)
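
To sanity-check the edited configs before starting Airflow, a minimal sketch along these lines loads them the same way the unit tests do. It assumes the sample YAML lives in `dags/config` and is illustrative only, not a script shipped with the repo:

```python
# Illustrative config check (not part of the repo): load the sample YAML
# configs and print the values that must point at your own project.
from gcp_airflow_foundations.base_class.utils import load_tables_config_from_dir

for config in load_tables_config_from_dir("dags/config"):
    source = config.source
    print(source.name, source.source_type, source.gcp_project)
    print(source.landing_zone_options.landing_zone_dataset)
    print(source.extra_options["gcs_bucket"])
```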

### Using Service Account
- Create a service account in GCP, and save its key as ```helpers/key/keys.json``` (don't worry, it is in .gitignore and will not be pushed to the git repo)
- Run Airflow locally (Airflow UI will be accessible at http://localhost:8080): ```docker-compose up```
- Default authentication values for the Airflow UI are provided in lines 96 and 97 of ```docker-compose.yaml```
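
As a quick sanity check that the key is picked up, something along these lines (assuming the `google-cloud-storage` client library is installed in your environment) should be able to list the public sample bucket used by the DAGs:

```python
# Illustrative check, not part of the repo: confirm the service account key
# can reach GCS by listing a few objects in the public sample bucket.
from google.cloud import storage

client = storage.Client.from_service_account_json("helpers/key/keys.json")
for blob in client.list_blobs("public-gcp-airflow-foundation-samples", max_results=5):
    print(blob.name)
```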
### Using user IAM
- uncomment like 11 in ```docker-composer.yaml```
- uncomment line 11 in ```docker-compose.yaml```
- set the env var PROJECT_ID to your test project
- Authorize gcloud to access the Cloud Platform with Google user credentials: ```helpers/scripts/gcp-auth.sh```
- Run Airflow locally (Airflow UI will be accessible at http://localhost:8080): ```docker-compose up```
- Default authentication values for the Airflow UI are provided in lines 96 and 97 of ```docker-compose.yaml```
### Running tests
- Run unit tests: ```./tests/airflow "pytest tests/unit"```
- Run unit tests with a coverage report: ```./tests/airflow "pytest --cov=gcp_airflow_foundations tests/unit"```
- Run integration tests: ```./tests/airflow "pytest tests/integration"```
- Rebuild the Docker image if requirements changed: ```docker-compose build```
### Sample DAGs
Sample DAGs that ingest publicly available GCS files can be found in the dags folder and are started as soon as Airflow is run locally

# Contributing
## Install pre-commit hook
Install pre-commit hooks for linting, format checking, etc.
2 changes: 1 addition & 1 deletion dags/config/gcs.yaml
@@ -29,4 +29,4 @@ tables:
extra_options:
file_table_config:
directory_prefix: ""
allow_quoted_newlines: True
allow_quoted_newlines: True
2 changes: 0 additions & 2 deletions dags/config/gcs_dlp.yaml
@@ -37,5 +37,3 @@ tables:
file_table_config:
directory_prefix: ""
allow_quoted_newlines: True


2 changes: 0 additions & 2 deletions dags/config/gcs_hds_dlp.yaml
@@ -41,5 +41,3 @@ tables:
file_table_config:
directory_prefix: ""
allow_quoted_newlines: True


6 changes: 3 additions & 3 deletions docker-compose.yaml
@@ -9,7 +9,7 @@ x-airflow-common:
&airflow-common-env
# uncomment to allow using user IAM for access
#AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT: 'google-cloud-platform://user:password@host/schema?extra__google_cloud_platform__scope=${GCP_AUTH_SCOPE}&extra__google_cloud_platform__project=${GCP_PROJECT_ID}'
AIRFLOW__CORE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres/airflow
AIRFLOW__CORE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow1!@postgres/airflow
env_file:
- ./variables/docker-env-vars
- ./variables/docker-env-secrets # added to gitignore
@@ -34,7 +34,7 @@ services:
image: postgres:13
environment:
POSTGRES_USER: airflow
POSTGRES_PASSWORD: airflow
POSTGRES_PASSWORD: airflow1!
POSTGRES_DB: airflow
volumes:
- postgres-db-volume:/var/lib/postgresql/data
@@ -94,7 +94,7 @@ services:
_AIRFLOW_DB_UPGRADE: 'true'
_AIRFLOW_WWW_USER_CREATE: 'true'
_AIRFLOW_WWW_USER_USERNAME: ${_AIRFLOW_WWW_USER_USERNAME:-airflow}
_AIRFLOW_WWW_USER_PASSWORD: ${_AIRFLOW_WWW_USER_PASSWORD:-airflow}
_AIRFLOW_WWW_USER_PASSWORD: ${_AIRFLOW_WWW_USER_PASSWORD:-airflow1!}
user: "0:${AIRFLOW_GID:-0}"

volumes:
4 changes: 2 additions & 2 deletions gcp_airflow_foundations/base_class/file_source_config.py
@@ -26,7 +26,7 @@ class FileSourceConfig:

date_format: str = "%Y-%m-%d"
airflow_date_template: str = "ds"
delete_gcs_files: bool = True
gcs_bucket_prefix: str = False
delete_gcs_files: bool = False
gcs_bucket_prefix: str = ""
delimeter: str = ","
sensor_timeout: int = 10800
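
The bulk of this hotfix replaces raw dictionary lookups on `extra_options["file_source_config"]` with attribute access on this dataclass, built once per task via dacite. A minimal sketch of the pattern, reusing the field values from the new unit-test config (illustrative, not code taken verbatim from the commit):

```python
# Build a typed FileSourceConfig from the raw YAML dict with dacite, then
# read attributes instead of indexing nested dictionaries.
from dacite import from_dict

from gcp_airflow_foundations.base_class.file_source_config import FileSourceConfig

raw = {
    "file_name_template": "{{ TABLE_NAME }}.csv",
    "source_format": "CSV",
    "delimeter": ",",
    "file_prefix_filtering": False,
    "delete_gcs_files": False,
    "sensor_timeout": 6000,
}
file_source_config = from_dict(data_class=FileSourceConfig, data=raw)

assert file_source_config.airflow_date_template == "ds"    # default
assert file_source_config.delete_gcs_files is False        # new default
assert file_source_config.gcs_bucket_prefix == ""          # was `False` before this fix
```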
41 changes: 24 additions & 17 deletions gcp_airflow_foundations/source_class/generic_file_source.py
@@ -3,6 +3,7 @@
from datetime import datetime
import json
from dacite import from_dict
from os.path import join

from airflow.models.dag import DAG
from airflow.operators.python_operator import PythonOperator
@@ -13,6 +14,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.sensors.gcs import GCSObjectExistenceSensor
from airflow.operators.dummy import DummyOperator
from gcp_airflow_foundations.base_class import file_source_config

from gcp_airflow_foundations.source_class.source import DagBuilder
from gcp_airflow_foundations.base_class.file_source_config import FileSourceConfig
@@ -30,6 +32,7 @@ def set_schema_method_type(self):

def get_bq_ingestion_task(self, dag, table_config):
taskgroup = TaskGroup(group_id="ftp_taskgroup")
file_source_config = from_dict(data_class=FileSourceConfig, data=self.config.source.extra_options["file_source_config"])

tasks = []

@@ -47,7 +50,7 @@ def get_bq_ingestion_task(self, dag, table_config):

tasks.append(self.load_to_landing_task(table_config, taskgroup))

if self.config.source.extra_options["file_source_config"]["delete_gcs_files"]:
if file_source_config.delete_gcs_files:
tasks.append(self.delete_gcs_files(table_config, taskgroup))

for task in tasks:
@@ -66,10 +69,11 @@ def metadata_file_sensor(self, table_config, taskgroup):
Implements a sensor for the metadata file specified in the table config, which specifies
the parameterized file names to ingest.
"""
file_source_config = from_dict(data_class=FileSourceConfig, data=self.config.source.extra_options["file_source_config"])
if "metadata_file" in table_config.extra_options.get("file_table_config"):
metadata_file_name = table_config.extra_options.get("file_table_config")["metadata_file"]
bucket = self.config.source.extra_options["gcs_bucket"]
timeout = self.config.source.extra_options["file_source_config"]["sensor_timeout"]
timeout = file_source_config.sensor_timeout

return GCSObjectExistenceSensor(
task_id="wait_for_metadata_file",
@@ -93,9 +97,10 @@ def schema_file_sensor(self, table_config, taskgroup):
"""
Implements an Airflow sensor to wait for an (optional) schema file in GCS
"""
file_source_config = from_dict(data_class=FileSourceConfig, data=self.config.source.extra_options["file_source_config"])
bucket = self.config.source.extra_options["gcs_bucket"]
schema_file_name = None
timeout = self.config.source.extra_options["file_source_config"]["sensor_timeout"]
timeout = file_source_config.sensor_timeout
if "schema_file" in table_config.extra_options.get("file_table_config"):
schema_file_name = table_config.extra_options.get("file_table_config")["schema_file"]
return GCSObjectExistenceSensor(
@@ -137,21 +142,22 @@ def get_file_list_task(self, table_config, taskgroup):

def get_list_of_files(self, table_config, **kwargs):
# gcs_hook = GCSHook()
airflow_date_template = self.config.source.extra_options["file_source_config"]["airflow_date_template"]
file_source_config = from_dict(data_class=FileSourceConfig, data=self.config.source.extra_options["file_source_config"])
airflow_date_template = file_source_config.airflow_date_template
if airflow_date_template == "ds":
ds = kwargs["ds"]
else:
ds = kwargs["prev_ds"]
ds = datetime.strptime(ds, "%Y-%m-%d").strftime(self.config.source.extra_options["file_source_config"]["date_format"])
ds = datetime.strptime(ds, "%Y-%m-%d").strftime(file_source_config.date_format)
logging.info(ds)
# XCom push the list of files
# overwrite if in table_config
dir_prefix = table_config.extra_options.get("file_table_config")["directory_prefix"]
dir_prefix = dir_prefix.replace("{{ ds }}", ds)

gcs_bucket_prefix = self.config.source.extra_options["file_source_config"]["gcs_bucket_prefix"]
gcs_bucket_prefix = file_source_config.gcs_bucket_prefix

if self.config.source.extra_options["file_source_config"]["source_format"] == "PARQUET":
if file_source_config.source_format == "PARQUET":
file_list = [dir_prefix]
kwargs['ti'].xcom_push(key='file_list', value=file_list)
return
@@ -165,14 +171,14 @@ def get_list_of_files(self, table_config, **kwargs):
for line in f:
file_list.append(line.strip())
else:
templated_file_name = self.config.source.extra_options["file_source_config"]["file_name_template"]
templated_file_name = file_source_config.file_name_template
templated_file_name = templated_file_name.replace("{{ TABLE_NAME }}", table_config.table_name)
file_list = [templated_file_name]

# support replacing files with current dates
file_list[:] = [file.replace("{{ ds }}", ds) if "{{ ds }}" in file else file for file in file_list]
# add dir prefix to files
file_list[:] = [gcs_bucket_prefix + "/" + file for file in file_list]
file_list[:] = [join(gcs_bucket_prefix, file) for file in file_list]
logging.info(file_list)

kwargs['ti'].xcom_push(key='file_list', value=file_list)
@@ -188,15 +194,16 @@ def load_to_landing_task(self, table_config, taskgroup):
# flake8: noqa: C901
def load_to_landing(self, table_config, **kwargs):
gcs_hook = GCSHook()
file_source_config = from_dict(data_class=FileSourceConfig, data=self.config.source.extra_options["file_source_config"])

# Parameters
ds = kwargs['ds']
ti = kwargs['ti']

data_source = self.config.source
bucket = data_source.extra_options["gcs_bucket"]
source_format = data_source.extra_options["file_source_config"]["source_format"]
field_delimeter = data_source.extra_options["file_source_config"]["delimeter"]
source_format = file_source_config.source_format
field_delimeter = file_source_config.delimeter
gcp_project = data_source.gcp_project
landing_dataset = data_source.landing_zone_options.landing_zone_dataset
landing_table_name = table_config.landing_zone_table_name_override
@@ -210,7 +217,7 @@ def load_to_landing(self, table_config, **kwargs):
dir_prefix = dir_prefix.replace("{{ ds }}", ds)
files_to_load = [dir_prefix]

gcs_bucket_prefix = data_source.extra_options["file_source_config"]["gcs_bucket_prefix"]
gcs_bucket_prefix = file_source_config.gcs_bucket_prefix
if gcs_bucket_prefix is None:
gcs_bucket_prefix = ""
if not gcs_bucket_prefix == "":
@@ -222,18 +229,18 @@ def load_to_landing(self, table_config, **kwargs):
destination_path_prefix = gcs_bucket_prefix + table_name + "/" + date
logging.info(destination_path_prefix)

files_to_load = [destination_path_prefix + "/" + f for f in files_to_load]
logging.info(files_to_load)
files_to_load = [destination_path_prefix + "/" + f for f in files_to_load]
logging.info(files_to_load)

if "parquet_upload_option" in table_config.extra_options.get("file_table_config"):
parquet_upload_option = table_config.extra_options.get("file_table_config")["parquet_upload_option"]
else:
parquet_upload_option = "BASH"

source_format = self.config.source.extra_options["file_source_config"]["source_format"]
source_format = file_source_config.source_format
if source_format == "PARQUET" and parquet_upload_option == "BASH":
date_column = table_config.extra_options.get("sftp_table_config")["date_column"]
gcs_bucket_prefix = data_source.extra_options["file_source_config"]["gcs_bucket_prefix"]
gcs_bucket_prefix = file_source_config.gcs_bucket_prefix
# bq load command if parquet
partition_prefix = ti.xcom_pull(key='partition_prefix', task_ids='ftp_taskgroup.load_sftp_to_gcs')
if not partition_prefix:
@@ -254,7 +261,7 @@ def load_to_landing(self, table_config, **kwargs):
logging.info("Load into BQ landing zone failed.")
else:
# gcs->bq operator else
if self.config.source.extra_options["file_source_config"]["file_prefix_filtering"]:
if file_source_config.file_prefix_filtering:
logging.info(files_to_load)
for i in range(len(files_to_load)):
matching_gcs_files = gcs_hook.list(bucket_name=bucket, prefix=files_to_load[i])
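
The other half of the fix is how GCS object names are assembled: templated file names are resolved and then combined with the bucket prefix via `os.path.join`, so an empty `gcs_bucket_prefix` no longer produces a leading slash. A small sketch of the behaviour this relies on, mirroring the expectation in the new unit test (illustrative only):

```python
# With an empty prefix, plain concatenation yields "/users.csv", which matches
# nothing in GCS; os.path.join simply drops the empty component.
from os.path import join

gcs_bucket_prefix = ""  # new default in FileSourceConfig
file_name = "{{ TABLE_NAME }}.csv".replace("{{ TABLE_NAME }}", "users")

old_style = gcs_bucket_prefix + "/" + file_name   # "/users.csv"
new_style = join(gcs_bucket_prefix, file_name)    # "users.csv"
print(old_style, new_style)
```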
Empty file.
32 changes: 32 additions & 0 deletions tests/unit/sources/gcs/gcs.yaml
@@ -0,0 +1,32 @@
source:
name: SampleGCS
source_type: GCS
ingest_schedule: "@daily"
start_date: "2022-01-03"
catchup: False
acceptable_delay_minutes: 5
extra_options:
gcs_bucket: public-gcp-airflow-foundation-samples
file_source_config:
file_name_template: "{{ TABLE_NAME }}.csv"
source_format: CSV
delimeter: ","
file_prefix_filtering: False
delete_gcs_files: False
sensor_timeout: 6000
gcp_project: airflow-framework
location: us
dataset_data_name: af_test_ods
# dataset_hds_override: af_test_hds
owner: test_user
notification_emails: []
landing_zone_options:
landing_zone_dataset: af_test_hds_landing_zone
tables:
- table_name: users
ingestion_type: FULL
surrogate_keys: []
extra_options:
file_table_config:
directory_prefix: ""
allow_quoted_newlines: True
88 changes: 88 additions & 0 deletions tests/unit/sources/gcs/test_gcs.py
@@ -0,0 +1,88 @@
from genericpath import exists
from keyword import kwlist
import os
import unittest

from gcp_airflow_foundations.base_class.utils import load_tables_config_from_dir
from gcp_airflow_foundations.source_class import gcs_source
from datetime import datetime
from gcp_airflow_foundations.base_class.file_source_config import FileSourceConfig
from gcp_airflow_foundations.base_class.file_table_config import FileTableConfig
from dacite import from_dict
from airflow.models.dag import DAG
from airflow.models import DagRun, DagTag, TaskInstance, DagModel
from airflow.utils.session import create_session
from airflow.utils.task_group import TaskGroup

DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = "unit_test_dag"
DEV_NULL = "/dev/null"


def clear_db_dags():
with create_session() as session:
session.query(DagTag).delete()
session.query(DagModel).delete()
session.query(DagRun).delete()
session.query(TaskInstance).delete()


class TestGcs(unittest.TestCase):
def setUp(self):
clear_db_dags()

here = os.path.abspath(os.path.dirname(__file__))
self.conf_location = here
self.config = next(iter(load_tables_config_from_dir(self.conf_location)), None)

self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}

def test_validate_yaml_good(self):
# Validate options
assert self.config.source.source_type == "GCS"
assert self.config.source.name is not None

# Validating extra options
file_source_config = from_dict(data_class=FileSourceConfig, data=self.config.source.extra_options["file_source_config"])
assert file_source_config.airflow_date_template == "ds"
assert file_source_config.date_format == "%Y-%m-%d"
assert file_source_config.delete_gcs_files is not True
assert file_source_config.gcs_bucket_prefix == ""

tables = self.config.tables
for table_config in tables:
assert table_config.table_name is not None
file_table_config = from_dict(data_class=FileTableConfig, data=table_config.extra_options.get("file_table_config"))
assert isinstance(file_table_config.directory_prefix, str)

def test_gcs_file_sensor_good(self):
gcs_dag_builder = gcs_source.GCSFileIngestionDagBuilder(
default_task_args=self.args, config=self.config
)

for table_config in self.config.tables:
with DAG(
TEST_DAG_ID, default_args=self.args, schedule_interval="@once"
) as dag:
task_group = TaskGroup("test", dag=dag)
file_sensor = gcs_dag_builder.file_sensor(table_config, task_group)
assert file_sensor is not None
assert file_sensor.bucket == "public-gcp-airflow-foundation-samples"
assert file_sensor.objects == "{{ ti.xcom_pull(key='file_list', task_ids='ftp_taskgroup.get_file_list') }}"

def test_gcs_list_files_good(self):
gcs_dag_builder = gcs_source.GCSFileIngestionDagBuilder(
default_task_args=self.args, config=self.config
)

for table_config in self.config.tables:
with DAG(
TEST_DAG_ID, default_args=self.args, schedule_interval="@once"
) as dag:
task_group = TaskGroup("test", dag=dag)
file_sensor = gcs_dag_builder.file_sensor(table_config, task_group)
task_instance = TaskInstance(file_sensor, execution_date=datetime.now())
task_instance.xcom_push(key='file_list', value='1')
gcs_dag_builder.get_list_of_files(table_config, ds="2022-04-19", ti=task_instance)
file_list = task_instance.xcom_pull(key='file_list')
assert file_list == ["users.csv"]
