Skip to content

Commit

Permalink
Move service/ to services/ui/, make sort_param unique for dag_run_sor…
Browse files Browse the repository at this point in the history
…t_param, remove unreachable statement from parameters.py::_transform_ti_states, make include upstream/downstream Annotated optional and include new test for upstream/downstream
  • Loading branch information
bugraoz93 committed Dec 8, 2024
1 parent cd1c8e3 commit a72e063
Show file tree
Hide file tree
Showing 14 changed files with 568 additions and 129 deletions.
106 changes: 36 additions & 70 deletions airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from airflow.typing_compat import Self
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
from sqlalchemy.sql import ColumnElement, Select
Expand Down Expand Up @@ -112,69 +113,6 @@ def depends(self, only_active: bool = True) -> _OnlyActiveFilter:
return self.set_value(only_active)


class DagIdsFilter(BaseParam[list[str]]):
"""Filter on dag ids."""

def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
super().__init__(value, skip_none)
self.model = model

def to_orm(self, select: Select) -> Select:
if self.value and self.skip_none:
return select.where(self.model.dag_id.in_(self.value))
return select

def depends(self, dag_ids: list[str] = Query(None)) -> DagIdsFilter:
return self.set_value(dag_ids)


class DagRunIdsFilter(BaseParam[list[str]]):
"""Filter on dag run ids."""

def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
super().__init__(value, skip_none)
self.model = model

def to_orm(self, select: Select) -> Select:
if self.value and self.skip_none:
return select.where(self.model.run_id.in_(self.value))
return select

def depends(self, dag_run_ids: list[str] = Query(None)) -> DagRunIdsFilter:
return self.set_value(dag_run_ids)


class DagRunRunTypesFilter(BaseParam[Optional[list[str]]]):
"""Filter on dag run run_types."""

def __init__(self, value: list[str] | None = None, skip_none: bool = True) -> None:
super().__init__(value, skip_none)

def to_orm(self, select: Select) -> Select:
if self.value and self.skip_none:
return select.where(DagRun.run_type.in_(self.value))
return select

def depends(self, run_types: list[str] = Query(None)) -> DagRunRunTypesFilter:
return self.set_value(run_types)


class TaskIdsFilter(BaseParam[list[str]]):
"""Filter on task ids."""

def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
super().__init__(value, skip_none)
self.model = model

def to_orm(self, select: Select) -> Select:
if self.value and self.skip_none:
return select.where(self.model.task_id.in_(self.value))
return select

def depends(self, task_ids: list[str] = Query(None)) -> TaskIdsFilter:
return self.set_value(task_ids)


class _SearchParam(BaseParam[str]):
"""Search on attribute."""

Expand Down Expand Up @@ -311,7 +249,6 @@ def to_orm(self, select: Select) -> Select:
return select
if self.value is None and self.skip_none:
return select

if isinstance(self.value, list):
if self.filter_option == FilterOptionEnum.IN:
return select.where(self.attribute.in_(self.value))
Expand Down Expand Up @@ -547,10 +484,9 @@ def depends_float(

# DagRun
QueryLastDagRunStateFilter = Annotated[
FilterParam[Optional[DagRunState]],
FilterParam[Union[DagRunState, None]],
Depends(filter_param_factory(DagRun.state, Optional[DagRunState], filter_name="last_dag_run_state")),
]
QueryDagRunRunTypesFilter = Annotated[DagRunRunTypesFilter, Depends(DagRunRunTypesFilter().depends)]


def _transform_dag_run_states(states: Iterable[str] | None) -> list[DagRunState | None] | None:
Expand Down Expand Up @@ -578,6 +514,32 @@ def _transform_dag_run_states(states: Iterable[str] | None) -> list[DagRunState
),
]


def _transform_dag_run_types(types: list[str] | None) -> list[DagRunType | None] | None:
try:
if not types:
return None
return [None if run_type in ("none", None) else DagRunType(run_type) for run_type in types]
except ValueError:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid value for state. Valid values are {', '.join(DagRunType)}",
)


QueryDagRunRunTypesFilter = Annotated[
FilterParam[list[str]],
Depends(
filter_param_factory(
attribute=DagRun.run_type,
_type=list[str],
filter_option=FilterOptionEnum.ANY_EQUAL,
default_factory=list,
transform_callable=_transform_dag_run_types,
)
),
]

# DAGTags
QueryDagTagPatternSearch = Annotated[
_SearchParam, Depends(search_param_factory(DagTag.name, "tag_name_pattern"))
Expand All @@ -595,7 +557,6 @@ def _transform_ti_states(states: list[str] | None) -> list[TaskInstanceState | N
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid value for state. Valid values are {', '.join(TaskInstanceState)}",
)
return states


QueryTIStateFilter = Annotated[
Expand Down Expand Up @@ -637,6 +598,11 @@ def _transform_ti_states(states: list[str] | None) -> list[TaskInstanceState | N
_DagIdAssetReferenceFilter, Depends(_DagIdAssetReferenceFilter().depends)
]

# UI Shared
QueryIncludeUpstream = Annotated[Union[bool, None], Depends(lambda: False)]
QueryIncludeDownstream = Annotated[Union[bool, None], Depends(lambda: False)]

# DAG Filter Upstream|Downstream
def _optional_boolean(value: bool | None) -> bool | None:
return value if value is not None else False


QueryIncludeUpstream = Annotated[Union[bool, None], AfterValidator(_optional_boolean)]
QueryIncludeDownstream = Annotated[Union[bool, None], AfterValidator(_optional_boolean)]
40 changes: 38 additions & 2 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,24 @@ paths:
schema:
type: string
title: Dag Id
- name: include_upstream
in: query
required: false
schema:
anyOf:
- type: boolean
- type: 'null'
default: false
title: Include Upstream
- name: include_downstream
in: query
required: false
schema:
anyOf:
- type: boolean
- type: 'null'
default: false
title: Include Downstream
- name: root
in: query
required: false
Expand Down Expand Up @@ -249,6 +267,24 @@ paths:
schema:
type: string
title: Dag Id
- name: include_upstream
in: query
required: false
schema:
anyOf:
- type: boolean
- type: 'null'
default: false
title: Include Upstream
- name: include_downstream
in: query
required: false
schema:
anyOf:
- type: boolean
- type: 'null'
default: false
title: Include Downstream
- name: base_date
in: query
required: false
Expand All @@ -265,14 +301,14 @@ paths:
- type: string
- type: 'null'
title: Root
- name: run_types
- name: run_type
in: query
required: false
schema:
type: array
items:
type: string
title: Run Types
title: Run Type
- name: state
in: query
required: false
Expand Down
27 changes: 19 additions & 8 deletions airflow/api_fastapi/core_api/routes/ui/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
GridResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.routes.ui.service.grid import (
from airflow.api_fastapi.core_api.services.ui.grid import (
fill_task_instance_summaries,
get_dag_run_sort_param,
get_task_group_map,
Expand All @@ -67,19 +67,15 @@ def grid_data(
offset: QueryOffset,
request: Request,
num_runs: QueryLimit,
include_upstream: QueryIncludeUpstream,
include_downstream: QueryIncludeDownstream,
include_upstream: QueryIncludeUpstream = False,
include_downstream: QueryIncludeDownstream = False,
base_date: OptionalDateTimeQuery = None,
root: str | None = None,
) -> GridResponse:
"""Return grid data."""
dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
if root:
dag = dag.partial_subset(
task_ids_or_regex=root, include_upstream=include_upstream, include_downstream=include_downstream
)

current_time = timezone.utcnow()
# Retrieve, sort and encode the previous DAG Runs
Expand Down Expand Up @@ -141,17 +137,33 @@ def grid_data(
task_instances = session.execute(tis_of_dag_runs)

# Generate Grouped Task Instances
task_node_map_exclude = None
if root:
task_node_map_exclude = get_task_group_map(
dag=dag.partial_subset(
task_ids_or_regex=root,
include_upstream=include_upstream,
include_downstream=include_downstream,
)
)

task_node_map = get_task_group_map(dag=dag)
parent_tis: dict[tuple[str, str], list] = collections.defaultdict(list)
all_tis: dict[tuple[str, str], list] = collections.defaultdict(list)
for ti in task_instances:
if task_node_map_exclude and ti.task_id not in task_node_map_exclude.keys():
continue
all_tis[(ti.task_id, ti.run_id)].append(ti)
parent_id = task_node_map[ti.task_id]["parent_id"]
if not parent_id and task_node_map[ti.task_id]["is_group"]:
parent_tis[(ti.task_id, ti.run_id)].append(ti)
elif parent_id and task_node_map[parent_id]["is_group"]:
parent_tis[(parent_id, ti.run_id)].append(ti)

# Clear task_node_map_exclude to free up memory
if task_node_map_exclude:
task_node_map_exclude.clear()

# Extend subgroup task instances to parent task instances to calculate the aggregates states
task_group_map = {k: v for k, v in task_node_map.items() if v["is_group"]}
parent_tis.update(
Expand All @@ -163,7 +175,6 @@ def grid_data(
if task_id_parent == task_map["parent_id"]
}
)

# Create the Task Instance Summaries to be used in the Grid Response
task_instance_summaries: dict[str, list] = {
run_id: [] for (_, run_id), _ in itertools.chain(parent_tis.items(), all_tis.items())
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/routes/ui/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def structure_data(
session: SessionDep,
dag_id: str,
request: Request,
include_upstream: QueryIncludeUpstream,
include_downstream: QueryIncludeDownstream,
include_upstream: QueryIncludeUpstream = False,
include_downstream: QueryIncludeDownstream = False,
root: str | None = None,
) -> StructureDataResponse:
"""Get Structure Data."""
Expand Down
16 changes: 16 additions & 0 deletions airflow/api_fastapi/core_api/services/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading

0 comments on commit a72e063

Please sign in to comment.