From 0c71786d1091be28edbc0392d42a653837c5e101 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sat, 16 Nov 2024 22:11:45 +0800 Subject: [PATCH 1/9] Add bulk post pools, refactor post pool --- .../api_fastapi/core_api/datamodels/pools.py | 23 ++++++- .../core_api/openapi/v1-generated.yaml | 66 +++++++++++++++++++ .../core_api/routes/public/pools.py | 42 ++++++++++++ airflow/ui/openapi-gen/queries/common.ts | 3 + airflow/ui/openapi-gen/queries/queries.ts | 38 +++++++++++ .../ui/openapi-gen/requests/schemas.gen.ts | 17 +++++ .../ui/openapi-gen/requests/services.gen.ts | 29 ++++++++ airflow/ui/openapi-gen/requests/types.gen.ts | 48 ++++++++++++++ 8 files changed, 264 insertions(+), 2 deletions(-) diff --git a/airflow/api_fastapi/core_api/datamodels/pools.py b/airflow/api_fastapi/core_api/datamodels/pools.py index ef3676a8afec7..1cc3838b23912 100644 --- a/airflow/api_fastapi/core_api/datamodels/pools.py +++ b/airflow/api_fastapi/core_api/datamodels/pools.py @@ -19,7 +19,7 @@ from typing import Annotated, Callable -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator def _call_function(function: Callable[[], int]) -> int: @@ -72,6 +72,25 @@ class PoolPatchBody(BaseModel): class PoolPostBody(BasePool): """Pool serializer for post bodies.""" - pool: str = Field(alias="name") + pool: str = Field(alias="name", max_length=256) description: str | None = None include_deferred: bool = False + + +class PoolPostBulkBody(BaseModel): + """Pools serializer for post bodies.""" + + pools: list[PoolPostBody] + + @field_validator("pools", mode="after") + def validate_pools(cls, v: list[PoolPostBody]) -> list[PoolPostBody]: + pool_set = set() + duplicates = [] + for pool in v: + if pool.pool in pool_set: + duplicates.append(pool.pool) + else: + pool_set.add(pool.pool) + if duplicates: + raise ValueError(f"Pool name should be unique, found duplicates: {duplicates}") + return v diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 2f74f2268928f..0b9ce1b592685 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3278,6 +3278,59 @@ paths: schema: $ref: '#/components/schemas/HTTPExceptionResponse' description: Forbidden + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /public/pools/bulk: + post: + tags: + - Pool + summary: Post Pools + description: Create multiple pools. + operationId: post_pools + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PoolPostBulkBody' + required: true + responses: + '201': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/PoolCollectionResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '400': + description: Validation error + example: {} + '409': + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' '422': description: Validation Error content: @@ -6544,6 +6597,7 @@ components: properties: name: type: string + maxLength: 256 title: Name slots: type: integer @@ -6563,6 +6617,18 @@ components: - slots title: PoolPostBody description: Pool serializer for post bodies. + PoolPostBulkBody: + properties: + pools: + items: + $ref: '#/components/schemas/PoolPostBody' + type: array + title: Pools + type: object + required: + - pools + title: PoolPostBulkBody + description: Pools serializer for post bodies. PoolResponse: properties: name: diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 582e03ab00dbd..2d8a64fb09a3a 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -32,6 +32,7 @@ PoolCollectionResponse, PoolPatchBody, PoolPostBody, + PoolPostBulkBody, PoolResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc @@ -160,14 +161,55 @@ def patch_pool( @pools_router.post( "", status_code=status.HTTP_201_CREATED, + responses=create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]), ) def post_pool( post_body: PoolPostBody, session: Annotated[Session, Depends(get_session)], ) -> PoolResponse: """Create a Pool.""" + pool = session.scalar(select(Pool).where(Pool.pool == post_body.pool)) + if pool is not None: + raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{post_body.pool}` already exists") pool = Pool(**post_body.model_dump()) session.add(pool) return PoolResponse.model_validate(pool, from_attributes=True) + + +@pools_router.post( + "/bulk", + status_code=status.HTTP_201_CREATED, + responses={ + status.HTTP_400_BAD_REQUEST: {"description": "Validation error", "example": {}}, + **create_openapi_http_exception_doc( + [ + status.HTTP_401_UNAUTHORIZED, + status.HTTP_403_FORBIDDEN, + status.HTTP_409_CONFLICT, + ] + ), + }, +) +def post_pools( + post_bulk_body: PoolPostBulkBody, + session: Annotated[Session, Depends(get_session)], +) -> PoolCollectionResponse: + """Create multiple pools.""" + # Check if any of the pools already exists + pools_names = [pool.pool for pool in post_bulk_body.pools] + existing_pools = session.scalars(select(Pool.pool).where(Pool.pool.in_(pools_names))).all() + if existing_pools: + raise HTTPException( + status.HTTP_409_CONFLICT, + detail=f"Pools with names: `{existing_pools}` already exist", + ) + + pools = [Pool(**post_body.model_dump()) for post_body in post_bulk_body.pools] + session.add_all(pools) + + return PoolCollectionResponse( + pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools], + total_entries=len(pools), + ) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 736425218022a..968c0617caf80 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1330,6 +1330,9 @@ export type DagRunServiceClearDagRunMutationResult = Awaited< export type PoolServicePostPoolMutationResult = Awaited< ReturnType >; +export type PoolServicePostPoolsMutationResult = Awaited< + ReturnType +>; export type TaskInstanceServiceGetTaskInstancesBatchMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 2ca159f465f51..b2343d0add10c 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -41,6 +41,7 @@ import { DagWarningType, PoolPatchBody, PoolPostBody, + PoolPostBulkBody, TaskInstancesBatchBody, VariableBody, } from "../requests/types.gen"; @@ -2363,6 +2364,43 @@ export const usePoolServicePostPool = < PoolService.postPool({ requestBody }) as unknown as Promise, ...options, }); +/** + * Post Pools + * Create multiple pools. + * @param data The data for the request. + * @param data.requestBody + * @returns PoolCollectionResponse Successful Response + * @throws ApiError + */ +export const usePoolServicePostPools = < + TData = Common.PoolServicePostPoolsMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + requestBody: PoolPostBulkBody; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + requestBody: PoolPostBulkBody; + }, + TContext + >({ + mutationFn: ({ requestBody }) => + PoolService.postPools({ requestBody }) as unknown as Promise, + ...options, + }); /** * Get Task Instances Batch * Get list of task instances. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index a0bb85ace80ad..3d51b72eba2e5 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -2848,6 +2848,7 @@ export const $PoolPostBody = { properties: { name: { type: "string", + maxLength: 256, title: "Name", }, slots: { @@ -2877,6 +2878,22 @@ export const $PoolPostBody = { description: "Pool serializer for post bodies.", } as const; +export const $PoolPostBulkBody = { + properties: { + pools: { + items: { + $ref: "#/components/schemas/PoolPostBody", + }, + type: "array", + title: "Pools", + }, + }, + type: "object", + required: ["pools"], + title: "PoolPostBulkBody", + description: "Pools serializer for post bodies.", +} as const; + export const $PoolResponse = { properties: { name: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 63272c5e0c470..5ca3efe12c0da 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -109,6 +109,8 @@ import type { GetPoolsResponse, PostPoolData, PostPoolResponse, + PostPoolsData, + PostPoolsResponse, GetProvidersData, GetProvidersResponse, GetTaskInstanceData, @@ -1790,6 +1792,33 @@ export class PoolService { errors: { 401: "Unauthorized", 403: "Forbidden", + 409: "Conflict", + 422: "Validation Error", + }, + }); + } + + /** + * Post Pools + * Create multiple pools. + * @param data The data for the request. + * @param data.requestBody + * @returns PoolCollectionResponse Successful Response + * @throws ApiError + */ + public static postPools( + data: PostPoolsData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/public/pools/bulk", + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Validation error", + 401: "Unauthorized", + 403: "Forbidden", + 409: "Conflict", 422: "Validation Error", }, }); diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 926932379b350..e325cb404af51 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -712,6 +712,13 @@ export type PoolPostBody = { include_deferred?: boolean; }; +/** + * Pools serializer for post bodies. + */ +export type PoolPostBulkBody = { + pools: Array; +}; + /** * Pool serializer for responses. */ @@ -1510,6 +1517,12 @@ export type PostPoolData = { export type PostPoolResponse = PoolResponse; +export type PostPoolsData = { + requestBody: PoolPostBulkBody; +}; + +export type PostPoolsResponse = PoolCollectionResponse; + export type GetProvidersData = { limit?: number; offset?: number; @@ -3107,6 +3120,41 @@ export type $OpenApiTs = { * Forbidden */ 403: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; + "/public/pools/bulk": { + post: { + req: PostPoolsData; + res: { + /** + * Successful Response + */ + 201: PoolCollectionResponse; + /** + * Validation error + */ + 400: unknown; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; /** * Validation Error */ From 86e324c9f2275b07edca2f81a30badd6bbd068d5 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sat, 16 Nov 2024 22:37:32 +0800 Subject: [PATCH 2/9] Add 409 case for TestPostPool --- .../core_api/routes/public/test_pools.py | 77 +++++++++++++++++-- 1 file changed, 70 insertions(+), 7 deletions(-) diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index 4a774f1a1e379..9996b29f83311 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -54,6 +54,28 @@ def teardown_method(self) -> None: def create_pools(self): _create_pools() + def _create_pool_in_test( + self, + test_client, + session, + body, + expected_status_code, + expected_response, + create_default=True, + check_count=True, + ): + if create_default: + self.create_pools() + if check_count: + n_pools = session.query(Pool).count() + response = test_client.post("/public/pools/", json=body) + assert response.status_code == expected_status_code + + body = response.json() + assert response.json() == expected_response + if check_count: + assert session.query(Pool).count() == n_pools + 1 + class TestDeletePool(TestPoolsEndpoint): def test_delete_should_respond_204(self, test_client, session): @@ -322,11 +344,52 @@ class TestPostPool(TestPoolsEndpoint): ], ) def test_should_respond_200(self, test_client, session, body, expected_status_code, expected_response): - self.create_pools() - n_pools = session.query(Pool).count() - response = test_client.post("/public/pools", json=body) - assert response.status_code == expected_status_code + self._create_pool_in_test(test_client, session, body, expected_status_code, expected_response) - body = response.json() - assert response.json() == expected_response - assert session.query(Pool).count() == n_pools + 1 + @pytest.mark.parametrize( + "body,first_expected_status_code, first_expected_response, second_expected_status_code, second_expected_response", + [ + ( + {"name": "my_pool", "slots": 11}, + 201, + { + "name": "my_pool", + "slots": 11, + "description": None, + "include_deferred": False, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "open_slots": 11, + "deferred_slots": 0, + }, + 409, + {"detail": "Pool with name: `my_pool` already exists"}, + ), + ], + ) + def test_should_response_409( + self, + test_client, + session, + body, + first_expected_status_code, + first_expected_response, + second_expected_status_code, + second_expected_response, + ): + # first request + self._create_pool_in_test( + test_client, session, body, first_expected_status_code, first_expected_response + ) + # second request + self._create_pool_in_test( + test_client, + session, + body, + second_expected_status_code, + second_expected_response, + create_default=False, + check_count=False, + ) From 64439896322112351085931212c90e87dbec926d Mon Sep 17 00:00:00 2001 From: jason810496 Date: Mon, 18 Nov 2024 10:33:50 +0800 Subject: [PATCH 3/9] Add test for bulk post pools --- .../core_api/routes/public/test_pools.py | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index 9996b29f83311..eb56d50b7e33a 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -393,3 +393,103 @@ def test_should_response_409( create_default=False, check_count=False, ) + + +class TestPostPools(TestPoolsEndpoint): + @pytest.mark.parametrize( + "body, expected_status_code, expected_response", + [ + ( + { + "pools": [ + {"name": "my_pool", "slots": 11}, + {"name": "my_pool2", "slots": 12}, + ] + }, + 201, + { + "pools": [ + { + "name": "my_pool", + "slots": 11, + "description": None, + "include_deferred": False, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "open_slots": 11, + "deferred_slots": 0, + }, + { + "name": "my_pool2", + "slots": 12, + "description": None, + "include_deferred": False, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "open_slots": 12, + "deferred_slots": 0, + }, + ], + "total_entries": 2, + }, + ), + ( + { + "pools": [ + {"name": "my_pool", "slots": 11}, + {"name": POOL1_NAME, "slots": 12}, + ] + }, + 409, + {}, + ), + ( + { + "pools": [ + {"name": POOL1_NAME, "slots": 11}, + {"name": POOL2_NAME, "slots": 12}, + ] + }, + 409, + {}, + ), + ( + { + "pools": [ + {"name": "my_pool", "slots": 11}, + {"name": "my_pool", "slots": 12}, + ] + }, + 422, + { + "detail": [ + { + "loc": ["body", "pools"], + "msg": "Value error, Pool name should be unique, found duplicates: ['my_pool']", + "type": "value_error", + } + ] + }, + ), + ], + ) + def test_post_pools(self, test_client, session, body, expected_status_code, expected_response): + self.create_pools() + n_pools = session.query(Pool).count() + response = test_client.post("/public/pools/bulk", json=body) + assert response.status_code == expected_status_code + response_json = response.json() + if expected_status_code == 201: + assert response_json == expected_response + elif expected_status_code == 422: + assert response_json["detail"][0]["loc"] == expected_response["detail"][0]["loc"] + assert response_json["detail"][0]["msg"] == expected_response["detail"][0]["msg"] + assert response_json["detail"][0]["type"] == expected_response["detail"][0]["type"] + if expected_status_code == 201: + assert session.query(Pool).count() == n_pools + 2 + else: + assert session.query(Pool).count() == n_pools From 01ba623baff61bb93fa9c2e6769db895f491cd89 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Mon, 18 Nov 2024 20:45:32 +0800 Subject: [PATCH 4/9] Remove unused status code, rename post_body to body --- .../core_api/openapi/v1-generated.yaml | 3 -- .../core_api/routes/public/pools.py | 29 ++++++++----------- .../ui/openapi-gen/requests/services.gen.ts | 1 - airflow/ui/openapi-gen/requests/types.gen.ts | 4 --- 4 files changed, 12 insertions(+), 25 deletions(-) diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 0b9ce1b592685..395a6acea5f6e 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3322,9 +3322,6 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPExceptionResponse' - '400': - description: Validation error - example: {} '409': description: Conflict content: diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 2d8a64fb09a3a..a812d6d52d862 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -164,14 +164,14 @@ def patch_pool( responses=create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]), ) def post_pool( - post_body: PoolPostBody, + body: PoolPostBody, session: Annotated[Session, Depends(get_session)], ) -> PoolResponse: """Create a Pool.""" - pool = session.scalar(select(Pool).where(Pool.pool == post_body.pool)) + pool = session.scalar(select(Pool).where(Pool.pool == body.pool)) if pool is not None: - raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{post_body.pool}` already exists") - pool = Pool(**post_body.model_dump()) + raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{body.pool}` already exists") + pool = Pool(**body.model_dump()) session.add(pool) @@ -181,24 +181,19 @@ def post_pool( @pools_router.post( "/bulk", status_code=status.HTTP_201_CREATED, - responses={ - status.HTTP_400_BAD_REQUEST: {"description": "Validation error", "example": {}}, - **create_openapi_http_exception_doc( - [ - status.HTTP_401_UNAUTHORIZED, - status.HTTP_403_FORBIDDEN, - status.HTTP_409_CONFLICT, - ] - ), - }, + responses=create_openapi_http_exception_doc( + [ + status.HTTP_409_CONFLICT, + ] + ), ) def post_pools( - post_bulk_body: PoolPostBulkBody, + body: PoolPostBulkBody, session: Annotated[Session, Depends(get_session)], ) -> PoolCollectionResponse: """Create multiple pools.""" # Check if any of the pools already exists - pools_names = [pool.pool for pool in post_bulk_body.pools] + pools_names = [pool.pool for pool in body.pools] existing_pools = session.scalars(select(Pool.pool).where(Pool.pool.in_(pools_names))).all() if existing_pools: raise HTTPException( @@ -206,7 +201,7 @@ def post_pools( detail=f"Pools with names: `{existing_pools}` already exist", ) - pools = [Pool(**post_body.model_dump()) for post_body in post_bulk_body.pools] + pools = [Pool(**body.model_dump()) for body in body.pools] session.add_all(pools) return PoolCollectionResponse( diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 5ca3efe12c0da..32338dfbf6d80 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -1815,7 +1815,6 @@ export class PoolService { body: data.requestBody, mediaType: "application/json", errors: { - 400: "Validation error", 401: "Unauthorized", 403: "Forbidden", 409: "Conflict", diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index e325cb404af51..73772b38502b7 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -3139,10 +3139,6 @@ export type $OpenApiTs = { * Successful Response */ 201: PoolCollectionResponse; - /** - * Validation error - */ - 400: unknown; /** * Unauthorized */ From 6d496189548112ef072ac53a87d821ec60840955 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Wed, 20 Nov 2024 20:16:22 +0800 Subject: [PATCH 5/9] Refactor duplicate pool insert handling - handle exception from db level instead of application level --- .../api_fastapi/core_api/datamodels/pools.py | 15 +----------- .../core_api/routes/public/pools.py | 24 ++++++++++--------- .../core_api/routes/public/test_pools.py | 19 +++------------ 3 files changed, 17 insertions(+), 41 deletions(-) diff --git a/airflow/api_fastapi/core_api/datamodels/pools.py b/airflow/api_fastapi/core_api/datamodels/pools.py index 1cc3838b23912..137392094cb5d 100644 --- a/airflow/api_fastapi/core_api/datamodels/pools.py +++ b/airflow/api_fastapi/core_api/datamodels/pools.py @@ -19,7 +19,7 @@ from typing import Annotated, Callable -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field def _call_function(function: Callable[[], int]) -> int: @@ -81,16 +81,3 @@ class PoolPostBulkBody(BaseModel): """Pools serializer for post bodies.""" pools: list[PoolPostBody] - - @field_validator("pools", mode="after") - def validate_pools(cls, v: list[PoolPostBody]) -> list[PoolPostBody]: - pool_set = set() - duplicates = [] - for pool in v: - if pool.pool in pool_set: - duplicates.append(pool.pool) - else: - pool_set.add(pool.pool) - if duplicates: - raise ValueError(f"Pool name should be unique, found duplicates: {duplicates}") - return v diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index a812d6d52d862..9e28947962f4d 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -22,6 +22,7 @@ from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import delete, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from airflow.api_fastapi.common.db.common import get_session, paginated_select @@ -168,12 +169,14 @@ def post_pool( session: Annotated[Session, Depends(get_session)], ) -> PoolResponse: """Create a Pool.""" - pool = session.scalar(select(Pool).where(Pool.pool == body.pool)) - if pool is not None: - raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{body.pool}` already exists") pool = Pool(**body.model_dump()) session.add(pool) + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{body.pool}` already exists") return PoolResponse.model_validate(pool, from_attributes=True) @@ -192,18 +195,17 @@ def post_pools( session: Annotated[Session, Depends(get_session)], ) -> PoolCollectionResponse: """Create multiple pools.""" - # Check if any of the pools already exists - pools_names = [pool.pool for pool in body.pools] - existing_pools = session.scalars(select(Pool.pool).where(Pool.pool.in_(pools_names))).all() - if existing_pools: + pools = [Pool(**body.model_dump()) for body in body.pools] + session.add_all(pools) + try: + session.commit() + except IntegrityError as e: + session.rollback() raise HTTPException( status.HTTP_409_CONFLICT, - detail=f"Pools with names: `{existing_pools}` already exist", + detail=f"One or more pools already exists. Error: {e}", ) - pools = [Pool(**body.model_dump()) for body in body.pools] - session.add_all(pools) - return PoolCollectionResponse( pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools], total_entries=len(pools), diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index eb56d50b7e33a..c74421a900f41 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -71,7 +71,6 @@ def _create_pool_in_test( response = test_client.post("/public/pools/", json=body) assert response.status_code == expected_status_code - body = response.json() assert response.json() == expected_response if check_count: assert session.query(Pool).count() == n_pools + 1 @@ -464,16 +463,8 @@ class TestPostPools(TestPoolsEndpoint): {"name": "my_pool", "slots": 12}, ] }, - 422, - { - "detail": [ - { - "loc": ["body", "pools"], - "msg": "Value error, Pool name should be unique, found duplicates: ['my_pool']", - "type": "value_error", - } - ] - }, + 409, + {}, ), ], ) @@ -485,11 +476,7 @@ def test_post_pools(self, test_client, session, body, expected_status_code, expe response_json = response.json() if expected_status_code == 201: assert response_json == expected_response - elif expected_status_code == 422: - assert response_json["detail"][0]["loc"] == expected_response["detail"][0]["loc"] - assert response_json["detail"][0]["msg"] == expected_response["detail"][0]["msg"] - assert response_json["detail"][0]["type"] == expected_response["detail"][0]["type"] - if expected_status_code == 201: assert session.query(Pool).count() == n_pools + 2 else: + # since different database backend return different error messages, we just check the status code assert session.query(Pool).count() == n_pools From dcd81d3d83dd21799c4ff8ea050631054348842e Mon Sep 17 00:00:00 2001 From: jason810496 Date: Thu, 21 Nov 2024 15:22:12 +0800 Subject: [PATCH 6/9] Add global database exception handler for fastapi --- airflow/api_fastapi/app.py | 9 +++- airflow/api_fastapi/common/exceptions.py | 64 ++++++++++++++++++++++++ airflow/api_fastapi/core_api/app.py | 8 +++ 3 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 airflow/api_fastapi/common/exceptions.py diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py index 4bf6ae9f6b77c..02841c8c211f5 100644 --- a/airflow/api_fastapi/app.py +++ b/airflow/api_fastapi/app.py @@ -22,7 +22,13 @@ from fastapi import FastAPI from starlette.routing import Mount -from airflow.api_fastapi.core_api.app import init_config, init_dag_bag, init_plugins, init_views +from airflow.api_fastapi.core_api.app import ( + init_config, + init_dag_bag, + init_error_handlers, + init_plugins, + init_views, +) from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.configuration import conf @@ -61,6 +67,7 @@ def create_app(apps: str = "all") -> FastAPI: init_dag_bag(app) init_views(app) init_plugins(app) + init_error_handlers(app) init_auth_manager() if "execution" in apps_list or "all" in apps_list: diff --git a/airflow/api_fastapi/common/exceptions.py b/airflow/api_fastapi/common/exceptions.py new file mode 100644 index 0000000000000..98d0c907ee732 --- /dev/null +++ b/airflow/api_fastapi/common/exceptions.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +from fastapi import HTTPException, Request, status +from sqlalchemy.exc import IntegrityError + +T = TypeVar("T") + + +class BaseErrorHandler(Generic[T], ABC): + """Base class for error handlers.""" + + def __init__(self, exception_cls: T) -> None: + self.exception_cls = exception_cls + + @abstractmethod + def exception_handler(self, request: Request, exc: T): + """exception_handler method.""" + raise NotImplementedError + + +class _UniqueConstraintErrorHandler(BaseErrorHandler[IntegrityError]): + """Exception raised when trying to insert a duplicate value in a unique column.""" + + def __init__(self): + super().__init__(IntegrityError) + self.unique_constraint_error_messages = [ + "UNIQUE constraint failed", # SQLite + "Duplicate entry", # MySQL + "violates unique constraint", # PostgreSQL + ] + + def exception_handler(self, request: Request, exc: IntegrityError): + """Handle IntegrityError exception.""" + exc_orig_str = str(exc.orig) + if any(error_msg in exc_orig_str for error_msg in self.unique_constraint_error_messages): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Unique constraint violation", + ) + + +DatabaseErrorHandlers = [ + _UniqueConstraintErrorHandler(), +] diff --git a/airflow/api_fastapi/core_api/app.py b/airflow/api_fastapi/core_api/app.py index fc29e51f999fe..0e0a375054b19 100644 --- a/airflow/api_fastapi/core_api/app.py +++ b/airflow/api_fastapi/core_api/app.py @@ -120,3 +120,11 @@ def init_config(app: FastAPI) -> None: app.add_middleware(GZipMiddleware, minimum_size=1024, compresslevel=5) app.state.secret_key = conf.get("webserver", "secret_key") + + +def init_error_handlers(app: FastAPI) -> None: + from airflow.api_fastapi.common.exceptions import DatabaseErrorHandlers + + # register database error handlers + for handler in DatabaseErrorHandlers: + app.add_exception_handler(handler.exception_cls, handler.exception_handler) From 531416b3a8a259cd46fdead279b0dfd1d7e5e1b5 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Thu, 21 Nov 2024 15:31:17 +0800 Subject: [PATCH 7/9] Remove manual handle for unique constraint exc --- .../core_api/routes/public/pools.py | 22 ++++--------------- .../core_api/routes/public/test_pools.py | 2 +- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 9e28947962f4d..ad2ba40ae2fca 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -22,7 +22,6 @@ from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import delete, select -from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from airflow.api_fastapi.common.db.common import get_session, paginated_select @@ -162,7 +161,9 @@ def patch_pool( @pools_router.post( "", status_code=status.HTTP_201_CREATED, - responses=create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]), + responses=create_openapi_http_exception_doc( + [status.HTTP_409_CONFLICT] + ), # handle by global exception handler ) def post_pool( body: PoolPostBody, @@ -170,13 +171,7 @@ def post_pool( ) -> PoolResponse: """Create a Pool.""" pool = Pool(**body.model_dump()) - session.add(pool) - try: - session.commit() - except IntegrityError: - session.rollback() - raise HTTPException(status.HTTP_409_CONFLICT, f"Pool with name: `{body.pool}` already exists") return PoolResponse.model_validate(pool, from_attributes=True) @@ -186,7 +181,7 @@ def post_pool( status_code=status.HTTP_201_CREATED, responses=create_openapi_http_exception_doc( [ - status.HTTP_409_CONFLICT, + status.HTTP_409_CONFLICT, # handle by global exception handler ] ), ) @@ -197,15 +192,6 @@ def post_pools( """Create multiple pools.""" pools = [Pool(**body.model_dump()) for body in body.pools] session.add_all(pools) - try: - session.commit() - except IntegrityError as e: - session.rollback() - raise HTTPException( - status.HTTP_409_CONFLICT, - detail=f"One or more pools already exists. Error: {e}", - ) - return PoolCollectionResponse( pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools], total_entries=len(pools), diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index c74421a900f41..2dbc59bcc70f1 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -364,7 +364,7 @@ def test_should_respond_200(self, test_client, session, body, expected_status_co "deferred_slots": 0, }, 409, - {"detail": "Pool with name: `my_pool` already exists"}, + {"detail": "Unique constraint violation"}, ), ], ) From 955acb865ac65284304dcb0303aaa1518861fd6d Mon Sep 17 00:00:00 2001 From: jason810496 Date: Thu, 21 Nov 2024 15:38:48 +0800 Subject: [PATCH 8/9] Refactor test_pools --- .../core_api/routes/public/test_pools.py | 63 +++++++------------ 1 file changed, 21 insertions(+), 42 deletions(-) diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index 2dbc59bcc70f1..1cbc62406636b 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -54,27 +54,6 @@ def teardown_method(self) -> None: def create_pools(self): _create_pools() - def _create_pool_in_test( - self, - test_client, - session, - body, - expected_status_code, - expected_response, - create_default=True, - check_count=True, - ): - if create_default: - self.create_pools() - if check_count: - n_pools = session.query(Pool).count() - response = test_client.post("/public/pools/", json=body) - assert response.status_code == expected_status_code - - assert response.json() == expected_response - if check_count: - assert session.query(Pool).count() == n_pools + 1 - class TestDeletePool(TestPoolsEndpoint): def test_delete_should_respond_204(self, test_client, session): @@ -343,7 +322,13 @@ class TestPostPool(TestPoolsEndpoint): ], ) def test_should_respond_200(self, test_client, session, body, expected_status_code, expected_response): - self._create_pool_in_test(test_client, session, body, expected_status_code, expected_response) + self.create_pools() + n_pools = session.query(Pool).count() + response = test_client.post("/public/pools/", json=body) + assert response.status_code == expected_status_code + + assert response.json() == expected_response + assert session.query(Pool).count() == n_pools + 1 @pytest.mark.parametrize( "body,first_expected_status_code, first_expected_response, second_expected_status_code, second_expected_response", @@ -378,20 +363,16 @@ def test_should_response_409( second_expected_status_code, second_expected_response, ): - # first request - self._create_pool_in_test( - test_client, session, body, first_expected_status_code, first_expected_response - ) - # second request - self._create_pool_in_test( - test_client, - session, - body, - second_expected_status_code, - second_expected_response, - create_default=False, - check_count=False, - ) + self.create_pools() + n_pools = session.query(Pool).count() + response = test_client.post("/public/pools/", json=body) + assert response.status_code == first_expected_status_code + assert response.json() == first_expected_response + assert session.query(Pool).count() == n_pools + 1 + response = test_client.post("/public/pools/", json=body) + assert response.status_code == second_expected_status_code + assert response.json() == second_expected_response + assert session.query(Pool).count() == n_pools + 1 class TestPostPools(TestPoolsEndpoint): @@ -444,7 +425,7 @@ class TestPostPools(TestPoolsEndpoint): ] }, 409, - {}, + {"detail": "Unique constraint violation"}, ), ( { @@ -454,7 +435,7 @@ class TestPostPools(TestPoolsEndpoint): ] }, 409, - {}, + {"detail": "Unique constraint violation"}, ), ( { @@ -464,7 +445,7 @@ class TestPostPools(TestPoolsEndpoint): ] }, 409, - {}, + {"detail": "Unique constraint violation"}, ), ], ) @@ -473,10 +454,8 @@ def test_post_pools(self, test_client, session, body, expected_status_code, expe n_pools = session.query(Pool).count() response = test_client.post("/public/pools/bulk", json=body) assert response.status_code == expected_status_code - response_json = response.json() + assert response.json() == expected_response if expected_status_code == 201: - assert response_json == expected_response assert session.query(Pool).count() == n_pools + 2 else: - # since different database backend return different error messages, we just check the status code assert session.query(Pool).count() == n_pools From c9b67c1a6db9cf166711e17be94695736172a13a Mon Sep 17 00:00:00 2001 From: jason810496 Date: Fri, 22 Nov 2024 15:54:59 +0800 Subject: [PATCH 9/9] Fix bound for TypeVar, type for comment --- airflow/api_fastapi/common/exceptions.py | 2 +- airflow/api_fastapi/core_api/routes/public/pools.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/api_fastapi/common/exceptions.py b/airflow/api_fastapi/common/exceptions.py index 98d0c907ee732..1e779a6097576 100644 --- a/airflow/api_fastapi/common/exceptions.py +++ b/airflow/api_fastapi/common/exceptions.py @@ -23,7 +23,7 @@ from fastapi import HTTPException, Request, status from sqlalchemy.exc import IntegrityError -T = TypeVar("T") +T = TypeVar("T", bound=Exception) class BaseErrorHandler(Generic[T], ABC): diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index ad2ba40ae2fca..0e67994acfaab 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -163,7 +163,7 @@ def patch_pool( status_code=status.HTTP_201_CREATED, responses=create_openapi_http_exception_doc( [status.HTTP_409_CONFLICT] - ), # handle by global exception handler + ), # handled by global exception handler ) def post_pool( body: PoolPostBody, @@ -181,7 +181,7 @@ def post_pool( status_code=status.HTTP_201_CREATED, responses=create_openapi_http_exception_doc( [ - status.HTTP_409_CONFLICT, # handle by global exception handler + status.HTTP_409_CONFLICT, # handled by global exception handler ] ), )