Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-81 Add Insert Multiple Pools API #44121

9 changes: 8 additions & 1 deletion airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from fastapi import FastAPI
from starlette.routing import Mount

from airflow.api_fastapi.core_api.app import init_config, init_dag_bag, init_plugins, init_views
from airflow.api_fastapi.core_api.app import (
init_config,
init_dag_bag,
init_error_handlers,
init_plugins,
init_views,
)
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.configuration import conf
Expand Down Expand Up @@ -61,6 +67,7 @@ def create_app(apps: str = "all") -> FastAPI:
init_dag_bag(app)
init_views(app)
init_plugins(app)
init_error_handlers(app)
init_auth_manager()

if "execution" in apps_list or "all" in apps_list:
Expand Down
64 changes: 64 additions & 0 deletions airflow/api_fastapi/common/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from fastapi import HTTPException, Request, status
from sqlalchemy.exc import IntegrityError

T = TypeVar("T", bound=Exception)


class BaseErrorHandler(Generic[T], ABC):
"""Base class for error handlers."""

def __init__(self, exception_cls: T) -> None:
self.exception_cls = exception_cls

@abstractmethod
def exception_handler(self, request: Request, exc: T):
"""exception_handler method."""
raise NotImplementedError


class _UniqueConstraintErrorHandler(BaseErrorHandler[IntegrityError]):
"""Exception raised when trying to insert a duplicate value in a unique column."""

def __init__(self):
super().__init__(IntegrityError)
self.unique_constraint_error_messages = [
"UNIQUE constraint failed", # SQLite
"Duplicate entry", # MySQL
"violates unique constraint", # PostgreSQL
]

def exception_handler(self, request: Request, exc: IntegrityError):
"""Handle IntegrityError exception."""
exc_orig_str = str(exc.orig)
if any(error_msg in exc_orig_str for error_msg in self.unique_constraint_error_messages):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Unique constraint violation",
)


DatabaseErrorHandlers = [
_UniqueConstraintErrorHandler(),
]
8 changes: 8 additions & 0 deletions airflow/api_fastapi/core_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,11 @@ def init_config(app: FastAPI) -> None:
app.add_middleware(GZipMiddleware, minimum_size=1024, compresslevel=5)

app.state.secret_key = conf.get("webserver", "secret_key")


def init_error_handlers(app: FastAPI) -> None:
from airflow.api_fastapi.common.exceptions import DatabaseErrorHandlers

# register database error handlers
for handler in DatabaseErrorHandlers:
app.add_exception_handler(handler.exception_cls, handler.exception_handler)
8 changes: 7 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ class PoolPatchBody(BaseModel):
class PoolPostBody(BasePool):
"""Pool serializer for post bodies."""

pool: str = Field(alias="name")
pool: str = Field(alias="name", max_length=256)
description: str | None = None
include_deferred: bool = False


class PoolPostBulkBody(BaseModel):
"""Pools serializer for post bodies."""

pools: list[PoolPostBody]
jason810496 marked this conversation as resolved.
Show resolved Hide resolved
63 changes: 63 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3278,6 +3278,56 @@ paths:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'409':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Conflict
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/pools/bulk:
post:
tags:
- Pool
summary: Post Pools
description: Create multiple pools.
operationId: post_pools
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/PoolPostBulkBody'
required: true
responses:
'201':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/PoolCollectionResponse'
'401':
description: Unauthorized
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
'403':
description: Forbidden
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
'409':
description: Conflict
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
'422':
description: Validation Error
content:
Expand Down Expand Up @@ -6544,6 +6594,7 @@ components:
properties:
name:
type: string
maxLength: 256
title: Name
slots:
type: integer
Expand All @@ -6563,6 +6614,18 @@ components:
- slots
title: PoolPostBody
description: Pool serializer for post bodies.
PoolPostBulkBody:
properties:
pools:
items:
$ref: '#/components/schemas/PoolPostBody'
type: array
title: Pools
type: object
required:
- pools
title: PoolPostBulkBody
description: Pools serializer for post bodies.
PoolResponse:
properties:
name:
Expand Down
31 changes: 28 additions & 3 deletions airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
PoolCollectionResponse,
PoolPatchBody,
PoolPostBody,
PoolPostBulkBody,
PoolResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
Expand Down Expand Up @@ -160,14 +161,38 @@ def patch_pool(
@pools_router.post(
"",
status_code=status.HTTP_201_CREATED,
responses=create_openapi_http_exception_doc(
[status.HTTP_409_CONFLICT]
), # handled by global exception handler
)
def post_pool(
post_body: PoolPostBody,
body: PoolPostBody,
session: Annotated[Session, Depends(get_session)],
) -> PoolResponse:
"""Create a Pool."""
pool = Pool(**post_body.model_dump())

pool = Pool(**body.model_dump())
session.add(pool)

return PoolResponse.model_validate(pool, from_attributes=True)


@pools_router.post(
"/bulk",
status_code=status.HTTP_201_CREATED,
responses=create_openapi_http_exception_doc(
[
status.HTTP_409_CONFLICT, # handled by global exception handler
]
),
)
def post_pools(
body: PoolPostBulkBody,
session: Annotated[Session, Depends(get_session)],
) -> PoolCollectionResponse:
"""Create multiple pools."""
pools = [Pool(**body.model_dump()) for body in body.pools]
session.add_all(pools)
return PoolCollectionResponse(
pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools],
total_entries=len(pools),
)
3 changes: 3 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,9 @@ export type DagRunServiceClearDagRunMutationResult = Awaited<
export type PoolServicePostPoolMutationResult = Awaited<
ReturnType<typeof PoolService.postPool>
>;
export type PoolServicePostPoolsMutationResult = Awaited<
ReturnType<typeof PoolService.postPools>
>;
export type TaskInstanceServiceGetTaskInstancesBatchMutationResult = Awaited<
ReturnType<typeof TaskInstanceService.getTaskInstancesBatch>
>;
Expand Down
38 changes: 38 additions & 0 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
DagWarningType,
PoolPatchBody,
PoolPostBody,
PoolPostBulkBody,
TaskInstancesBatchBody,
VariableBody,
} from "../requests/types.gen";
Expand Down Expand Up @@ -2363,6 +2364,43 @@ export const usePoolServicePostPool = <
PoolService.postPool({ requestBody }) as unknown as Promise<TData>,
...options,
});
/**
* Post Pools
* Create multiple pools.
* @param data The data for the request.
* @param data.requestBody
* @returns PoolCollectionResponse Successful Response
* @throws ApiError
*/
export const usePoolServicePostPools = <
TData = Common.PoolServicePostPoolsMutationResult,
TError = unknown,
TContext = unknown,
>(
options?: Omit<
UseMutationOptions<
TData,
TError,
{
requestBody: PoolPostBulkBody;
},
TContext
>,
"mutationFn"
>,
) =>
useMutation<
TData,
TError,
{
requestBody: PoolPostBulkBody;
},
TContext
>({
mutationFn: ({ requestBody }) =>
PoolService.postPools({ requestBody }) as unknown as Promise<TData>,
...options,
});
/**
* Get Task Instances Batch
* Get list of task instances.
Expand Down
17 changes: 17 additions & 0 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2848,6 +2848,7 @@ export const $PoolPostBody = {
properties: {
name: {
type: "string",
maxLength: 256,
title: "Name",
},
slots: {
Expand Down Expand Up @@ -2877,6 +2878,22 @@ export const $PoolPostBody = {
description: "Pool serializer for post bodies.",
} as const;

export const $PoolPostBulkBody = {
properties: {
pools: {
items: {
$ref: "#/components/schemas/PoolPostBody",
},
type: "array",
title: "Pools",
},
},
type: "object",
required: ["pools"],
title: "PoolPostBulkBody",
description: "Pools serializer for post bodies.",
} as const;

export const $PoolResponse = {
properties: {
name: {
Expand Down
28 changes: 28 additions & 0 deletions airflow/ui/openapi-gen/requests/services.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ import type {
GetPoolsResponse,
PostPoolData,
PostPoolResponse,
PostPoolsData,
PostPoolsResponse,
GetProvidersData,
GetProvidersResponse,
GetTaskInstanceData,
Expand Down Expand Up @@ -1790,6 +1792,32 @@ export class PoolService {
errors: {
401: "Unauthorized",
403: "Forbidden",
409: "Conflict",
422: "Validation Error",
},
});
}

/**
* Post Pools
* Create multiple pools.
* @param data The data for the request.
* @param data.requestBody
* @returns PoolCollectionResponse Successful Response
* @throws ApiError
*/
public static postPools(
data: PostPoolsData,
): CancelablePromise<PostPoolsResponse> {
return __request(OpenAPI, {
method: "POST",
url: "/public/pools/bulk",
body: data.requestBody,
mediaType: "application/json",
errors: {
401: "Unauthorized",
403: "Forbidden",
409: "Conflict",
422: "Validation Error",
},
});
Expand Down
Loading