From 1e94cfe337eb233e8ee226ef877a9b4ac2d49af0 Mon Sep 17 00:00:00 2001 From: HarryWu-CHN <904714159@qq.com> Date: Sun, 7 Jan 2024 14:55:27 +0800 Subject: [PATCH 1/7] [api] get dag and dagRun can return the specified field --- .../api_connexion/endpoints/dag_endpoint.py | 41 +++++++++++++++---- .../endpoints/dag_run_endpoint.py | 24 ++++++++++- airflow/api_connexion/openapi/v1.yaml | 20 +++++++++ 3 files changed, 76 insertions(+), 9 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 3486244b39cc8..aa32e24175b33 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -30,7 +30,9 @@ from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.dag_schema import ( DAGCollection, - dag_detail_schema, + DAGCollectionSchema, + DAGDetailSchema, + DAGSchema, dag_schema, dags_collection_schema, ) @@ -50,19 +52,32 @@ @security.requires_access_dag("GET") @provide_session -def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: +def get_dag( + *, + dag_id: str, + fields: Collection[str] | None = None, + session: Session = NEW_SESSION +) -> APIResponse: """Get basic information about a DAG.""" dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id)) - + print(f"{fields=}") if dag is None: raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found") - - return dag_schema.dump(dag) + try: + dag_schema = DAGSchema(only = fields) if fields else DAGSchema() + except ValueError as e: + raise BadRequest("DAGSchema init error", detail=str(e)) + return dag_schema.dump(dag, ) @security.requires_access_dag("GET") @provide_session -def get_dag_details(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: +def get_dag_details( + *, + dag_id: str, + fields: Collection[str] | None = None, + session: Session = NEW_SESSION +) -> APIResponse: """Get details of DAG.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) if not dag: @@ -71,7 +86,10 @@ def get_dag_details(*, dag_id: str, session: Session = NEW_SESSION) -> APIRespon for key, value in dag.__dict__.items(): if not key.startswith("_") and not hasattr(dag_model, key): setattr(dag_model, key, value) - + try: + dag_detail_schema = DAGDetailSchema(only = fields) if fields else DAGDetailSchema() + except ValueError as e: + raise BadRequest("DAGDetailSchema init error", detail=str(e)) return dag_detail_schema.dump(dag_model) @@ -87,6 +105,7 @@ def get_dags( only_active: bool = True, paused: bool | None = None, order_by: str = "dag_id", + fields: Collection[str] | None = None, session: Session = NEW_SESSION, ) -> APIResponse: """Get all DAGs.""" @@ -113,6 +132,14 @@ def get_dags( dags_query = apply_sorting(dags_query, order_by, {}, allowed_attrs) dags = session.scalars(dags_query.offset(offset).limit(limit)).all() + + try: + dags_collection_schema = DAGCollectionSchema( + only=[f"dags.{field}" for field in fields] + ) if fields else DAGCollectionSchema() + except ValueError as e: + raise BadRequest("DAGCollectionSchema init error", detail=str(e)) + return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 45e064764c62e..72e5b2393fdfb 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -17,7 +17,7 @@ from __future__ import annotations from http import HTTPStatus -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Collection import pendulum from connexion import NoContent @@ -41,6 +41,8 @@ ) from airflow.api_connexion.schemas.dag_run_schema import ( DAGRunCollection, + DAGRunSchema, + DAGRunCollectionSchema, clear_dagrun_form_schema, dagrun_collection_schema, dagrun_schema, @@ -91,7 +93,13 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSI @security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session -def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: +def get_dag_run( + *, + dag_id: str, + dag_run_id: str, + fields: Collection[str] | None = None, + session: Session = NEW_SESSION +) -> APIResponse: """Get a DAG Run.""" dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)) if dag_run is None: @@ -99,6 +107,11 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) "DAGRun not found", detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found", ) + try: + dagrun_schema = DAGRunSchema(only = fields) if fields else DAGRunSchema() + except ValueError as e: + # Invalid fields + raise BadRequest("DAGRunSchema init error", detail=str(e)) return dagrun_schema.dump(dag_run) @@ -210,6 +223,7 @@ def get_dag_runs( offset: int | None = None, limit: int | None = None, order_by: str = "id", + fields: Collection[str] | None = None, session: Session = NEW_SESSION, ): """Get all DAG Runs.""" @@ -241,6 +255,12 @@ def get_dag_runs( order_by=order_by, session=session, ) + try: + dagrun_collection_schema = DAGRunCollectionSchema( + only=[f"dag_runs.{field}" for field in fields] + ) if fields else DAGRunCollectionSchema() + except ValueError as e: + raise BadRequest("DAGRunCollectionSchema init error", detail=str(e)) return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 60f4c2bc624b3..5c826933349d6 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -421,6 +421,7 @@ paths: - $ref: "#/components/parameters/FilterTags" - $ref: "#/components/parameters/OnlyActive" - $ref: "#/components/parameters/Paused" + - $ref: "#/components/parameters/ReturnFields" - name: dag_id_pattern in: query schema: @@ -497,6 +498,8 @@ paths: x-openapi-router-controller: airflow.api_connexion.endpoints.dag_endpoint operationId: get_dag tags: [DAG] + parameters: + - $ref: "#/components/parameters/ReturnFields" responses: "200": description: Success. @@ -1755,6 +1758,8 @@ paths: The response contains many DAG attributes, so the response can be large. If possible, consider using GET /dags/{dag_id}. tags: [DAG] + parameters: + - $ref: "#/components/parameters/ReturnFields" responses: "200": description: Success. @@ -3571,15 +3576,19 @@ components: properties: timezone: $ref: "#/components/schemas/Timezone" + nullable: true catchup: type: boolean readOnly: true + # nullable: true orientation: type: string readOnly: true + # nullable: true concurrency: type: number readOnly: true + nullable: true start_date: type: string format: "date-time" @@ -3591,6 +3600,7 @@ components: *Changed in version 2.0.1*: Field becomes nullable. dag_run_timeout: $ref: "#/components/schemas/TimeDelta" + nullable: true doc_md: type: string readOnly: true @@ -5253,6 +5263,16 @@ components: style: form explode: false + ReturnFields: + in: query + name: fields + schema: + type: array + items: + type: string + description: | + List of field for return. + # Reusable request bodies requestBodies: {} From 76bde1e2a92efa4d1bb615fbe8122d7a3d52d7bf Mon Sep 17 00:00:00 2001 From: HarryWu-CHN <904714159@qq.com> Date: Sun, 7 Jan 2024 15:11:40 +0800 Subject: [PATCH 2/7] pre-commit --- .../api_connexion/endpoints/dag_endpoint.py | 30 ++++++++----------- .../endpoints/dag_run_endpoint.py | 18 +++++------ airflow/www/static/js/types/api-generated.ts | 24 +++++++++++---- 3 files changed, 40 insertions(+), 32 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index aa32e24175b33..b178009493339 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -53,30 +53,25 @@ @security.requires_access_dag("GET") @provide_session def get_dag( - *, - dag_id: str, - fields: Collection[str] | None = None, - session: Session = NEW_SESSION + *, dag_id: str, fields: Collection[str] | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Get basic information about a DAG.""" dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id)) - print(f"{fields=}") if dag is None: raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found") try: - dag_schema = DAGSchema(only = fields) if fields else DAGSchema() + dag_schema = DAGSchema(only=fields) if fields else DAGSchema() except ValueError as e: raise BadRequest("DAGSchema init error", detail=str(e)) - return dag_schema.dump(dag, ) + return dag_schema.dump( + dag, + ) @security.requires_access_dag("GET") @provide_session def get_dag_details( - *, - dag_id: str, - fields: Collection[str] | None = None, - session: Session = NEW_SESSION + *, dag_id: str, fields: Collection[str] | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Get details of DAG.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) @@ -87,7 +82,7 @@ def get_dag_details( if not key.startswith("_") and not hasattr(dag_model, key): setattr(dag_model, key, value) try: - dag_detail_schema = DAGDetailSchema(only = fields) if fields else DAGDetailSchema() + dag_detail_schema = DAGDetailSchema(only=fields) if fields else DAGDetailSchema() except ValueError as e: raise BadRequest("DAGDetailSchema init error", detail=str(e)) return dag_detail_schema.dump(dag_model) @@ -132,14 +127,15 @@ def get_dags( dags_query = apply_sorting(dags_query, order_by, {}, allowed_attrs) dags = session.scalars(dags_query.offset(offset).limit(limit)).all() - try: - dags_collection_schema = DAGCollectionSchema( - only=[f"dags.{field}" for field in fields] - ) if fields else DAGCollectionSchema() + dags_collection_schema = ( + DAGCollectionSchema(only=[f"dags.{field}" for field in fields]) + if fields + else DAGCollectionSchema() + ) except ValueError as e: raise BadRequest("DAGCollectionSchema init error", detail=str(e)) - + return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 72e5b2393fdfb..1e2cf367d8bd1 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -41,8 +41,8 @@ ) from airflow.api_connexion.schemas.dag_run_schema import ( DAGRunCollection, - DAGRunSchema, DAGRunCollectionSchema, + DAGRunSchema, clear_dagrun_form_schema, dagrun_collection_schema, dagrun_schema, @@ -94,11 +94,7 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSI @security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session def get_dag_run( - *, - dag_id: str, - dag_run_id: str, - fields: Collection[str] | None = None, - session: Session = NEW_SESSION + *, dag_id: str, dag_run_id: str, fields: Collection[str] | None = None, session: Session = NEW_SESSION ) -> APIResponse: """Get a DAG Run.""" dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)) @@ -108,7 +104,7 @@ def get_dag_run( detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found", ) try: - dagrun_schema = DAGRunSchema(only = fields) if fields else DAGRunSchema() + dagrun_schema = DAGRunSchema(only=fields) if fields else DAGRunSchema() except ValueError as e: # Invalid fields raise BadRequest("DAGRunSchema init error", detail=str(e)) @@ -256,9 +252,11 @@ def get_dag_runs( session=session, ) try: - dagrun_collection_schema = DAGRunCollectionSchema( - only=[f"dag_runs.{field}" for field in fields] - ) if fields else DAGRunCollectionSchema() + dagrun_collection_schema = ( + DAGRunCollectionSchema(only=[f"dag_runs.{field}" for field in fields]) + if fields + else DAGRunCollectionSchema() + ) except ValueError as e: raise BadRequest("DAGRunCollectionSchema init error", detail=str(e)) return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)) diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts index 55ade6179d3c8..cf644ec9c5e25 100644 --- a/airflow/www/static/js/types/api-generated.ts +++ b/airflow/www/static/js/types/api-generated.ts @@ -1470,10 +1470,10 @@ export interface components { * [airflow.models.dag.DAG](https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/models/dag/index.html#airflow.models.dag.DAG) */ DAGDetail: components["schemas"]["DAG"] & { - timezone?: components["schemas"]["Timezone"]; + timezone?: components["schemas"]["Timezone"] | null; catchup?: boolean; orientation?: string; - concurrency?: number; + concurrency?: number | null; /** * Format: date-time * @description The DAG's start date. @@ -1481,7 +1481,7 @@ export interface components { * *Changed in version 2.0.1*: Field becomes nullable. */ start_date?: string | null; - dag_run_timeout?: components["schemas"]["TimeDelta"]; + dag_run_timeout?: components["schemas"]["TimeDelta"] | null; doc_md?: string | null; default_view?: string | null; /** @@ -2472,6 +2472,8 @@ export interface components { * A comma-separated list of fully qualified names of fields. */ UpdateMask: string[]; + /** @description List of field for return. */ + ReturnFields: string[]; }; requestBodies: {}; headers: {}; @@ -2659,6 +2661,8 @@ export interface operations { * *New in version 2.6.0* */ paused?: components["parameters"]["Paused"]; + /** List of field for return. */ + fields?: components["parameters"]["ReturnFields"]; /** If set, only return DAGs with dag_ids matching this pattern. */ dag_id_pattern?: string; }; @@ -2733,6 +2737,10 @@ export interface operations { /** The DAG ID. */ dag_id: components["parameters"]["DAGID"]; }; + query: { + /** List of field for return. */ + fields?: components["parameters"]["ReturnFields"]; + }; }; responses: { /** Success. */ @@ -4098,6 +4106,10 @@ export interface operations { /** The DAG ID. */ dag_id: components["parameters"]["DAGID"]; }; + query: { + /** List of field for return. */ + fields?: components["parameters"]["ReturnFields"]; + }; }; responses: { /** Success. */ @@ -4988,7 +5000,8 @@ export type PatchDagsVariables = CamelCasedPropertiesDeep< operations["patch_dags"]["requestBody"]["content"]["application/json"] >; export type GetDagVariables = CamelCasedPropertiesDeep< - operations["get_dag"]["parameters"]["path"] + operations["get_dag"]["parameters"]["path"] & + operations["get_dag"]["parameters"]["query"] >; export type DeleteDagVariables = CamelCasedPropertiesDeep< operations["delete_dag"]["parameters"]["path"] @@ -5133,7 +5146,8 @@ export type GetLogVariables = CamelCasedPropertiesDeep< operations["get_log"]["parameters"]["query"] >; export type GetDagDetailsVariables = CamelCasedPropertiesDeep< - operations["get_dag_details"]["parameters"]["path"] + operations["get_dag_details"]["parameters"]["path"] & + operations["get_dag_details"]["parameters"]["query"] >; export type GetTasksVariables = CamelCasedPropertiesDeep< operations["get_tasks"]["parameters"]["path"] & From 0ef95e559756987f9de760e01df65b06cdc80496 Mon Sep 17 00:00:00 2001 From: HarryWu-CHN <904714159@qq.com> Date: Sun, 7 Jan 2024 15:23:55 +0800 Subject: [PATCH 3/7] swagger yaml --- airflow/api_connexion/openapi/v1.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 5c826933349d6..d0ce84e3470ed 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -736,6 +736,7 @@ paths: - $ref: "#/components/parameters/FilterUpdatedAtLTE" - $ref: "#/components/parameters/FilterState" - $ref: "#/components/parameters/OrderBy" + - $ref: "#/components/parameters/ReturnFields" responses: "200": description: List of DAG runs. @@ -817,6 +818,8 @@ paths: x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint operationId: get_dag_run tags: [DAGRun] + parameters: + - $ref: "#/components/parameters/ReturnFields" responses: "200": description: Success. From 4ed779c1cc0c185b6bfdad0cfcbfa5a2a94fdb30 Mon Sep 17 00:00:00 2001 From: HarryWu-CHN <904714159@qq.com> Date: Sun, 7 Jan 2024 16:26:39 +0800 Subject: [PATCH 4/7] pre-commit --- airflow/www/static/js/types/api-generated.ts | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts index cf644ec9c5e25..735c9ede3fcb6 100644 --- a/airflow/www/static/js/types/api-generated.ts +++ b/airflow/www/static/js/types/api-generated.ts @@ -3005,6 +3005,8 @@ export interface operations { * *New in version 2.1.0* */ order_by?: components["parameters"]["OrderBy"]; + /** List of field for return. */ + fields?: components["parameters"]["ReturnFields"]; }; }; responses: { @@ -3071,6 +3073,10 @@ export interface operations { /** The DAG run ID. */ dag_run_id: components["parameters"]["DAGRunID"]; }; + query: { + /** List of field for return. */ + fields?: components["parameters"]["ReturnFields"]; + }; }; responses: { /** Success. */ @@ -5039,7 +5045,8 @@ export type GetDagRunsBatchVariables = CamelCasedPropertiesDeep< operations["get_dag_runs_batch"]["requestBody"]["content"]["application/json"] >; export type GetDagRunVariables = CamelCasedPropertiesDeep< - operations["get_dag_run"]["parameters"]["path"] + operations["get_dag_run"]["parameters"]["path"] & + operations["get_dag_run"]["parameters"]["query"] >; export type DeleteDagRunVariables = CamelCasedPropertiesDeep< operations["delete_dag_run"]["parameters"]["path"] From 52d05ac18fe65ed4bb0721886cf4ece1f70aa579 Mon Sep 17 00:00:00 2001 From: HarryWu-CHN <904714159@qq.com> Date: Sun, 7 Jan 2024 23:32:49 +0800 Subject: [PATCH 5/7] tmp --- .../endpoints/test_dag_endpoint.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index c02e8b0ff3fca..52b4384093b61 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -259,6 +259,52 @@ def test_should_respond_403_with_granular_access_for_different_dag(self): "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} ) assert response.status_code == 403 + + @pytest.mark.parametrize( + "fields", + [ + [], # empty test + ["dag_id"], # only one + ["fileloc", "file_token", "owners"], # fields.Method and other + ["schedule_interval", "tags"], # fields.List + ], + ) + def test_should_return_specified_fields(self, fields): + self._create_dag_models(1) + response = self.client.get(f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"}) + res_json = response.json + for field in fields: + assert field in res_json + for field in { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/dag_1.py", + "file_token": "Ii90bXAvZGFnXzEucHki.EnmIdPaUPo26lHQClbWMbDFD1Pk", + "is_paused": False, + "is_active": True, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"}, + "tags": [], + "next_dagrun": None, + "has_task_concurrency_limits": True, + "next_dagrun_data_interval_start": None, + "next_dagrun_data_interval_end": None, + "max_active_runs": 16, + "next_dagrun_create_after": None, + "last_expired": None, + "max_active_tasks": 16, + "last_pickled": None, + "default_view": None, + "last_parsed_time": None, + "scheduler_lock": None, + "timetable_description": None, + "has_import_errors": False, + "pickle_id": None, + }.keys(): + if field not in fields: + assert field not in res_json class TestGetDagDetails(TestDagEndpoint): From 37fabb7f7a8b322a2c9ff0efd696683cd20e33ee Mon Sep 17 00:00:00 2001 From: HarryWu-CHN <904714159@qq.com> Date: Wed, 10 Jan 2024 11:32:33 +0800 Subject: [PATCH 6/7] add tests to dag/dagrun return fields --- .../api_connexion/endpoints/dag_endpoint.py | 5 +- .../endpoints/dag_run_endpoint.py | 15 ++-- airflow/api_connexion/openapi/v1.yaml | 4 +- .../api_connexion/schemas/dag_run_schema.py | 12 ++- airflow/www/static/js/types/api-generated.ts | 4 +- .../endpoints/test_dag_endpoint.py | 83 +++++++++++++++++-- .../endpoints/test_dag_run_endpoint.py | 83 +++++++++++++++++++ 7 files changed, 184 insertions(+), 22 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index b178009493339..9ecdf290ac5c8 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -133,10 +133,9 @@ def get_dags( if fields else DAGCollectionSchema() ) + return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) except ValueError as e: - raise BadRequest("DAGCollectionSchema init error", detail=str(e)) - - return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) + raise BadRequest("DAGCollectionSchema error", detail=str(e)) @security.requires_access_dag("PUT") diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 1e2cf367d8bd1..149af771e4a7c 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -104,11 +104,12 @@ def get_dag_run( detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found", ) try: - dagrun_schema = DAGRunSchema(only=fields) if fields else DAGRunSchema() + # parse fields to Schema @post_dump + dagrun_schema = DAGRunSchema(context={"fields": fields}) if fields else DAGRunSchema() + return dagrun_schema.dump(dag_run) except ValueError as e: # Invalid fields - raise BadRequest("DAGRunSchema init error", detail=str(e)) - return dagrun_schema.dump(dag_run) + raise BadRequest("DAGRunSchema error", detail=str(e)) @security.requires_access_dag("GET", DagAccessEntity.RUN) @@ -253,13 +254,11 @@ def get_dag_runs( ) try: dagrun_collection_schema = ( - DAGRunCollectionSchema(only=[f"dag_runs.{field}" for field in fields]) - if fields - else DAGRunCollectionSchema() + DAGRunCollectionSchema(context={"fields": fields}) if fields else DAGRunCollectionSchema() ) + return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)) except ValueError as e: - raise BadRequest("DAGRunCollectionSchema init error", detail=str(e)) - return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)) + raise BadRequest("DAGRunCollectionSchema error", detail=str(e)) @security.requires_access_dag("GET", DagAccessEntity.RUN) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index d0ce84e3470ed..80b11f3293fa0 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -3583,11 +3583,11 @@ components: catchup: type: boolean readOnly: true - # nullable: true + nullable: true orientation: type: string readOnly: true - # nullable: true + nullable: true concurrency: type: number readOnly: true diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py index d9cecb9b0b137..da01751f5969e 100644 --- a/airflow/api_connexion/schemas/dag_run_schema.py +++ b/airflow/api_connexion/schemas/dag_run_schema.py @@ -108,8 +108,18 @@ def autogenerate(self, data, **kwargs): @post_dump def autofill(self, data, **kwargs): """Populate execution_date from logical_date for compatibility.""" + ret_data = {} data["execution_date"] = data["logical_date"] - return data + if self.context.get("fields"): + ret_fields = self.context.get("fields") + for ret_field in ret_fields: + if ret_field not in data: + raise ValueError(f"{ret_field} not in DAGRunSchema") + ret_data[ret_field] = data[ret_field] + else: + ret_data = data + + return ret_data class SetDagRunStateFormSchema(Schema): diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts index 735c9ede3fcb6..9888deae84cba 100644 --- a/airflow/www/static/js/types/api-generated.ts +++ b/airflow/www/static/js/types/api-generated.ts @@ -1471,8 +1471,8 @@ export interface components { */ DAGDetail: components["schemas"]["DAG"] & { timezone?: components["schemas"]["Timezone"] | null; - catchup?: boolean; - orientation?: string; + catchup?: boolean | null; + orientation?: string | null; concurrency?: number | null; /** * Format: date-time diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 52b4384093b61..ec11984d3527d 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -259,19 +259,20 @@ def test_should_respond_403_with_granular_access_for_different_dag(self): "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} ) assert response.status_code == 403 - + @pytest.mark.parametrize( "fields", [ - [], # empty test - ["dag_id"], # only one - ["fileloc", "file_token", "owners"], # fields.Method and other - ["schedule_interval", "tags"], # fields.List + ["dag_id"], # only one + ["fileloc", "file_token", "owners"], # auto_field and fields.Method + ["schedule_interval", "tags"], # fields.List ], ) def test_should_return_specified_fields(self, fields): self._create_dag_models(1) - response = self.client.get(f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get( + f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} + ) res_json = response.json for field in fields: assert field in res_json @@ -306,6 +307,21 @@ def test_should_return_specified_fields(self, fields): if field not in fields: assert field not in res_json + @pytest.mark.parametrize( + "fields", + [ + [], # empty test + ["#caw&c"], # field which not exists + ["dag_id", "#caw&c"], # field which not exists + ], + ) + def test_should_respond_400_with_not_exists_fields(self, fields): + self._create_dag_models(1) + response = self.client.get( + f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} + ) + assert response.status_code == 400, f"Current code: {response.status_code}" + class TestGetDagDetails(TestDagEndpoint): def test_should_respond_200(self, url_safe_serializer): @@ -607,6 +623,35 @@ def test_should_raise_404_when_dag_is_not_found(self): "type": EXCEPTIONS_LINK_MAP[404], } + @pytest.mark.parametrize( + "fields", + [ + ["dag_id"], # only one + ["doc_md", "file_token", "owners"], # fields.String and fields.Method + ["schedule_interval", "tags"], # fields.List + ], + ) + def test_should_return_specified_fields(self, fields): + self._create_dag_model_for_details_endpoint(self.dag2_id) + response = self.client.get( + f"/api/v1/dags/{self.dag2_id}/details?fields={','.join(fields)}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 200 + res_json = response.json + assert len(res_json.keys()) == len(fields) + for field in fields: + assert field in res_json + + def test_should_respond_400_with_not_exists_fields(self): + fields = ["#caw&c"] + self._create_dag_model_for_details_endpoint(self.dag2_id) + response = self.client.get( + f"/api/v1/dags/{self.dag2_id}/details?fields={','.join(fields)}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400, f"Current code: {response.status_code}" + class TestGetDags(TestDagEndpoint): @provide_session @@ -1094,6 +1139,32 @@ def test_paused_none_returns_all_dags(self, url_safe_serializer): "total_entries": 2, } == response.json + def test_should_return_specified_fields(self): + self._create_dag_models(2) + self._create_deactivated_dag() + + fields = ["dag_id", "file_token", "owners"] + response = self.client.get( + f"api/v1/dags?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} + ) + assert response.status_code == 200 + + res_json = response.json + for dag in res_json["dags"]: + assert len(dag.keys()) == len(fields) + for field in fields: + assert field in dag + + def test_should_respond_400_with_not_exists_fields(self): + self._create_dag_models(1) + self._create_deactivated_dag() + fields = ["#caw&c"] + response = self.client.get( + f"api/v1/dags?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} + ) + + assert response.status_code == 400, f"Current code: {response.status_code}" + class TestPatchDag(TestDagEndpoint): def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer): diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 2c4c393dd3022..0ce3f222db446 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -280,6 +280,59 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) + @pytest.mark.parametrize( + "fields", + [ + ["dag_run_id", "logical_date"], + ["dag_run_id", "state", "conf", "execution_date"], + ], + ) + def test_should_return_specified_fields(self, session, fields): + dagrun_model = DagRun( + dag_id="TEST_DAG_ID", + run_id="TEST_DAG_RUN_ID", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state="running", + ) + session.add(dagrun_model) + session.commit() + result = session.query(DagRun).all() + assert len(result) == 1 + response = self.client.get( + f"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID?fields={','.join(fields)}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 200 + res_json = response.json + print("get dagRun", res_json) + assert len(res_json.keys()) == len(fields) + for field in fields: + assert field in res_json + + def test_should_respond_400_with_not_exists_fields(self, session): + dagrun_model = DagRun( + dag_id="TEST_DAG_ID", + run_id="TEST_DAG_RUN_ID", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state="running", + ) + session.add(dagrun_model) + session.commit() + result = session.query(DagRun).all() + assert len(result) == 1 + fields = ["#caw&c"] + response = self.client.get( + f"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID?fields={','.join(fields)}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400, f"Current code: {response.status_code}" + class TestGetDagRuns(TestDagRunEndpoint): def test_should_respond_200(self, session): @@ -425,6 +478,36 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) + @pytest.mark.parametrize( + "fields", + [ + ["dag_run_id", "logical_date"], + ["dag_run_id", "state", "conf", "execution_date"], + ], + ) + def test_should_return_specified_fields(self, session, fields): + self._create_test_dag_run() + result = session.query(DagRun).all() + assert len(result) == 2 + response = self.client.get( + f"api/v1/dags/TEST_DAG_ID/dagRuns?fields={','.join(fields)}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 200 + for dag_run in response.json["dag_runs"]: + assert len(dag_run.keys()) == len(fields) + for field in fields: + assert field in dag_run + + def test_should_respond_400_with_not_exists_fields(self): + self._create_test_dag_run() + fields = ["#caw&c"] + response = self.client.get( + f"api/v1/dags/TEST_DAG_ID/dagRuns?fields={','.join(fields)}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400, f"Current code: {response.status_code}" + class TestGetDagRunsPagination(TestDagRunEndpoint): @pytest.mark.parametrize( From 98676d485d57dcfcabb8dfded5a236c1a4bd4433 Mon Sep 17 00:00:00 2001 From: HarryWu-CHN <904714159@qq.com> Date: Wed, 10 Jan 2024 12:11:53 +0800 Subject: [PATCH 7/7] fix --- .../endpoints/test_dag_endpoint.py | 31 +------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index ec11984d3527d..86a9d4e474b25 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -274,38 +274,9 @@ def test_should_return_specified_fields(self, fields): f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} ) res_json = response.json + assert len(res_json.keys()) == len(fields) for field in fields: assert field in res_json - for field in { - "dag_id": "TEST_DAG_1", - "description": None, - "fileloc": "/tmp/dag_1.py", - "file_token": "Ii90bXAvZGFnXzEucHki.EnmIdPaUPo26lHQClbWMbDFD1Pk", - "is_paused": False, - "is_active": True, - "is_subdag": False, - "owners": [], - "root_dag_id": None, - "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"}, - "tags": [], - "next_dagrun": None, - "has_task_concurrency_limits": True, - "next_dagrun_data_interval_start": None, - "next_dagrun_data_interval_end": None, - "max_active_runs": 16, - "next_dagrun_create_after": None, - "last_expired": None, - "max_active_tasks": 16, - "last_pickled": None, - "default_view": None, - "last_parsed_time": None, - "scheduler_lock": None, - "timetable_description": None, - "has_import_errors": False, - "pickle_id": None, - }.keys(): - if field not in fields: - assert field not in res_json @pytest.mark.parametrize( "fields",