diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index de7350a45add2..fa193954a93e1 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -47,6 +47,7 @@ from airflow.typing_compat import Self from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.types import DagRunType if TYPE_CHECKING: from sqlalchemy.sql import ColumnElement, Select @@ -112,69 +113,6 @@ def depends(self, only_active: bool = True) -> _OnlyActiveFilter: return self.set_value(only_active) -class DagIdsFilter(BaseParam[list[str]]): - """Filter on dag ids.""" - - def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None: - super().__init__(value, skip_none) - self.model = model - - def to_orm(self, select: Select) -> Select: - if self.value and self.skip_none: - return select.where(self.model.dag_id.in_(self.value)) - return select - - def depends(self, dag_ids: list[str] = Query(None)) -> DagIdsFilter: - return self.set_value(dag_ids) - - -class DagRunIdsFilter(BaseParam[list[str]]): - """Filter on dag run ids.""" - - def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None: - super().__init__(value, skip_none) - self.model = model - - def to_orm(self, select: Select) -> Select: - if self.value and self.skip_none: - return select.where(self.model.run_id.in_(self.value)) - return select - - def depends(self, dag_run_ids: list[str] = Query(None)) -> DagRunIdsFilter: - return self.set_value(dag_run_ids) - - -class DagRunRunTypesFilter(BaseParam[Optional[list[str]]]): - """Filter on dag run run_types.""" - - def __init__(self, value: list[str] | None = None, skip_none: bool = True) -> None: - super().__init__(value, skip_none) - - def to_orm(self, select: Select) -> Select: - if self.value and self.skip_none: - return select.where(DagRun.run_type.in_(self.value)) - return select - - def depends(self, run_types: list[str] = Query(None)) -> DagRunRunTypesFilter: - return self.set_value(run_types) - - -class TaskIdsFilter(BaseParam[list[str]]): - """Filter on task ids.""" - - def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None: - super().__init__(value, skip_none) - self.model = model - - def to_orm(self, select: Select) -> Select: - if self.value and self.skip_none: - return select.where(self.model.task_id.in_(self.value)) - return select - - def depends(self, task_ids: list[str] = Query(None)) -> TaskIdsFilter: - return self.set_value(task_ids) - - class _SearchParam(BaseParam[str]): """Search on attribute.""" @@ -311,7 +249,6 @@ def to_orm(self, select: Select) -> Select: return select if self.value is None and self.skip_none: return select - if isinstance(self.value, list): if self.filter_option == FilterOptionEnum.IN: return select.where(self.attribute.in_(self.value)) @@ -547,10 +484,9 @@ def depends_float( # DagRun QueryLastDagRunStateFilter = Annotated[ - FilterParam[Optional[DagRunState]], + FilterParam[Union[DagRunState, None]], Depends(filter_param_factory(DagRun.state, Optional[DagRunState], filter_name="last_dag_run_state")), ] -QueryDagRunRunTypesFilter = Annotated[DagRunRunTypesFilter, Depends(DagRunRunTypesFilter().depends)] def _transform_dag_run_states(states: Iterable[str] | None) -> list[DagRunState | None] | None: @@ -578,6 +514,32 @@ def _transform_dag_run_states(states: Iterable[str] | None) -> list[DagRunState ), ] + +def _transform_dag_run_types(types: list[str] | None) -> list[DagRunType | None] | None: + try: + if not types: + return None + return [None if run_type in ("none", None) else DagRunType(run_type) for run_type in types] + except ValueError: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Invalid value for state. Valid values are {', '.join(DagRunType)}", + ) + + +QueryDagRunRunTypesFilter = Annotated[ + FilterParam[list[str]], + Depends( + filter_param_factory( + attribute=DagRun.run_type, + _type=list[str], + filter_option=FilterOptionEnum.ANY_EQUAL, + default_factory=list, + transform_callable=_transform_dag_run_types, + ) + ), +] + # DAGTags QueryDagTagPatternSearch = Annotated[ _SearchParam, Depends(search_param_factory(DagTag.name, "tag_name_pattern")) @@ -595,7 +557,6 @@ def _transform_ti_states(states: list[str] | None) -> list[TaskInstanceState | N status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Invalid value for state. Valid values are {', '.join(TaskInstanceState)}", ) - return states QueryTIStateFilter = Annotated[ @@ -637,6 +598,11 @@ def _transform_ti_states(states: list[str] | None) -> list[TaskInstanceState | N _DagIdAssetReferenceFilter, Depends(_DagIdAssetReferenceFilter().depends) ] -# UI Shared -QueryIncludeUpstream = Annotated[Union[bool, None], Depends(lambda: False)] -QueryIncludeDownstream = Annotated[Union[bool, None], Depends(lambda: False)] + +# DAG Filter Upstream|Downstream +def _optional_boolean(value: bool | None) -> bool | None: + return value if value is not None else False + + +QueryIncludeUpstream = Annotated[Union[bool, None], AfterValidator(_optional_boolean)] +QueryIncludeDownstream = Annotated[Union[bool, None], AfterValidator(_optional_boolean)] diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index bca5e0a5a863d..f88a7cbabd254 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -208,6 +208,24 @@ paths: schema: type: string title: Dag Id + - name: include_upstream + in: query + required: false + schema: + anyOf: + - type: boolean + - type: 'null' + default: false + title: Include Upstream + - name: include_downstream + in: query + required: false + schema: + anyOf: + - type: boolean + - type: 'null' + default: false + title: Include Downstream - name: root in: query required: false @@ -249,6 +267,24 @@ paths: schema: type: string title: Dag Id + - name: include_upstream + in: query + required: false + schema: + anyOf: + - type: boolean + - type: 'null' + default: false + title: Include Upstream + - name: include_downstream + in: query + required: false + schema: + anyOf: + - type: boolean + - type: 'null' + default: false + title: Include Downstream - name: base_date in: query required: false @@ -265,14 +301,14 @@ paths: - type: string - type: 'null' title: Root - - name: run_types + - name: run_type in: query required: false schema: type: array items: type: string - title: Run Types + title: Run Type - name: state in: query required: false diff --git a/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow/api_fastapi/core_api/routes/ui/grid.py index a5d28fb819645..a694dd7ec35bd 100644 --- a/airflow/api_fastapi/core_api/routes/ui/grid.py +++ b/airflow/api_fastapi/core_api/routes/ui/grid.py @@ -41,7 +41,7 @@ GridResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.routes.ui.service.grid import ( +from airflow.api_fastapi.core_api.services.ui.grid import ( fill_task_instance_summaries, get_dag_run_sort_param, get_task_group_map, @@ -67,8 +67,8 @@ def grid_data( offset: QueryOffset, request: Request, num_runs: QueryLimit, - include_upstream: QueryIncludeUpstream, - include_downstream: QueryIncludeDownstream, + include_upstream: QueryIncludeUpstream = False, + include_downstream: QueryIncludeDownstream = False, base_date: OptionalDateTimeQuery = None, root: str | None = None, ) -> GridResponse: @@ -76,10 +76,6 @@ def grid_data( dag: DAG = request.app.state.dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") - if root: - dag = dag.partial_subset( - task_ids_or_regex=root, include_upstream=include_upstream, include_downstream=include_downstream - ) current_time = timezone.utcnow() # Retrieve, sort and encode the previous DAG Runs @@ -141,10 +137,22 @@ def grid_data( task_instances = session.execute(tis_of_dag_runs) # Generate Grouped Task Instances + task_node_map_exclude = None + if root: + task_node_map_exclude = get_task_group_map( + dag=dag.partial_subset( + task_ids_or_regex=root, + include_upstream=include_upstream, + include_downstream=include_downstream, + ) + ) + task_node_map = get_task_group_map(dag=dag) parent_tis: dict[tuple[str, str], list] = collections.defaultdict(list) all_tis: dict[tuple[str, str], list] = collections.defaultdict(list) for ti in task_instances: + if task_node_map_exclude and ti.task_id not in task_node_map_exclude.keys(): + continue all_tis[(ti.task_id, ti.run_id)].append(ti) parent_id = task_node_map[ti.task_id]["parent_id"] if not parent_id and task_node_map[ti.task_id]["is_group"]: @@ -152,6 +160,10 @@ def grid_data( elif parent_id and task_node_map[parent_id]["is_group"]: parent_tis[(parent_id, ti.run_id)].append(ti) + # Clear task_node_map_exclude to free up memory + if task_node_map_exclude: + task_node_map_exclude.clear() + # Extend subgroup task instances to parent task instances to calculate the aggregates states task_group_map = {k: v for k, v in task_node_map.items() if v["is_group"]} parent_tis.update( @@ -163,7 +175,6 @@ def grid_data( if task_id_parent == task_map["parent_id"] } ) - # Create the Task Instance Summaries to be used in the Grid Response task_instance_summaries: dict[str, list] = { run_id: [] for (_, run_id), _ in itertools.chain(parent_tis.items(), all_tis.items()) diff --git a/airflow/api_fastapi/core_api/routes/ui/structure.py b/airflow/api_fastapi/core_api/routes/ui/structure.py index 98fb6bd85a921..75991b24b52f2 100644 --- a/airflow/api_fastapi/core_api/routes/ui/structure.py +++ b/airflow/api_fastapi/core_api/routes/ui/structure.py @@ -38,8 +38,8 @@ def structure_data( session: SessionDep, dag_id: str, request: Request, - include_upstream: QueryIncludeUpstream, - include_downstream: QueryIncludeDownstream, + include_upstream: QueryIncludeUpstream = False, + include_downstream: QueryIncludeDownstream = False, root: str | None = None, ) -> StructureDataResponse: """Get Structure Data.""" diff --git a/airflow/api_fastapi/core_api/routes/ui/service/__init__.py b/airflow/api_fastapi/core_api/services/__init__.py similarity index 100% rename from airflow/api_fastapi/core_api/routes/ui/service/__init__.py rename to airflow/api_fastapi/core_api/services/__init__.py diff --git a/airflow/api_fastapi/core_api/services/ui/__init__.py b/airflow/api_fastapi/core_api/services/ui/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/api_fastapi/core_api/services/ui/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_fastapi/core_api/routes/ui/service/grid.py b/airflow/api_fastapi/core_api/services/ui/grid.py similarity index 88% rename from airflow/api_fastapi/core_api/routes/ui/service/grid.py rename to airflow/api_fastapi/core_api/services/ui/grid.py index 6c0af7e5c8448..b2cc8c43bf39a 100644 --- a/airflow/api_fastapi/core_api/routes/ui/service/grid.py +++ b/airflow/api_fastapi/core_api/services/ui/grid.py @@ -54,13 +54,15 @@ def get_dag_run_sort_param(dag: DAG) -> BaseParam: :return: Sort Param """ + sort_param = SortParam( + allowed_attrs=["logical_date", "data_interval_start", "data_interval_end"], model=DagRun + ) + for name in dag.timetable.run_ordering: if name in ("data_interval_start", "data_interval_end"): - return SortParam( - allowed_attrs=["logical_date", "data_interval_start", "data_interval_end"], model=DagRun - ).set_value(name) + return sort_param.set_value(name) else: - return SortParam(allowed_attrs=["logical_date"], model=DagRun).set_value("logical_date") + return sort_param.set_value("logical_date") raise AirflowConfigException(f"No valid sort column found in run_ordering for {dag.dag_id}") @@ -96,15 +98,17 @@ def _fill_task_group_map( """Recursively fill the Task Group Map.""" if task_node is None: return - if isinstance(task_node, MappedOperator): task_nodes[task_node.node_id] = { "is_group": False, "parent_id": parent_node.node_id if parent_node else None, - "task_count": task_nodes[parent_node.node_id]["task_count"] - if _is_task_node_mapped_task_group(parent_node) and parent_node - else task_node, + "task_count": [task_node], } + if isinstance(parent_node, TaskGroup): + # Remove the regular task counted in parent_node + task_nodes[parent_node.node_id]["task_count"].append(-1) + # Add the mapped task to the parent_node + task_nodes[parent_node.node_id]["task_count"].append(task_node) return elif isinstance(task_node, BaseOperator): task_nodes[task_node.task_id] = { @@ -112,16 +116,16 @@ def _fill_task_group_map( "parent_id": parent_node.node_id if parent_node else None, "task_count": task_nodes[parent_node.node_id]["task_count"] if _is_task_node_mapped_task_group(parent_node) and parent_node - else 1, + else [1], } return elif isinstance(task_node, TaskGroup): task_nodes[task_node.node_id] = { "is_group": True, "parent_id": parent_node.node_id if parent_node else None, - "task_count": task_node + "task_count": [task_node] if _is_task_node_mapped_task_group(task_node) - else len([child for child in get_task_group_children_getter()(task_node)]), + else [len([child for child in get_task_group_children_getter()(task_node)])], } return [ _fill_task_group_map(task_node=child, parent_node=task_node) @@ -188,6 +192,11 @@ def fill_task_instance_summaries( ) # Task Count is either integer or a TaskGroup to get the task count task_count = task_node_map[task_id]["task_count"] + final_task_count = sum( + node if isinstance(node, int) else node.get_mapped_ti_count(run_id=run_id, session=session) + for node in task_count + ) + task_instance_summaries_to_fill[run_id].append( GridTaskInstanceSummary( task_id=task_id, @@ -196,9 +205,7 @@ def fill_task_instance_summaries( end_date=ti_end_date, queued_dttm=ti_queued_dttm, states=all_states, - task_count=task_count - if type(task_count) is int - else task_count.get_mapped_ti_count(run_id=run_id, session=session), + task_count=final_task_count, overall_state=overall_state, note=ti_note, ) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 15f918a2a39fd..64453d76e3f86 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -340,13 +340,20 @@ export const useStructureServiceStructureDataKey = export const UseStructureServiceStructureDataKeyFn = ( { dagId, + includeDownstream, + includeUpstream, root, }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; root?: string; }, queryKey?: Array, -) => [useStructureServiceStructureDataKey, ...(queryKey ?? [{ dagId, root }])]; +) => [ + useStructureServiceStructureDataKey, + ...(queryKey ?? [{ dagId, includeDownstream, includeUpstream, root }]), +]; export type GridServiceGridDataDefaultResponse = Awaited< ReturnType >; @@ -359,24 +366,40 @@ export const UseGridServiceGridDataKeyFn = ( { baseDate, dagId, + includeDownstream, + includeUpstream, limit, offset, root, - runTypes, + runType, state, }: { baseDate?: string; dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; root?: string; - runTypes?: string[]; + runType?: string[]; state?: string[]; }, queryKey?: Array, ) => [ useGridServiceGridDataKey, - ...(queryKey ?? [{ baseDate, dagId, limit, offset, root, runTypes, state }]), + ...(queryKey ?? [ + { + baseDate, + dagId, + includeDownstream, + includeUpstream, + limit, + offset, + root, + runType, + state, + }, + ]), ]; export type BackfillServiceListBackfillsDefaultResponse = Awaited< ReturnType diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index 06bddecfca3ae..77e9976d69381 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -412,6 +412,8 @@ export const prefetchUseDashboardServiceHistoricalMetrics = ( * Get Structure Data. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream * @param data.root * @returns StructureDataResponse Successful Response * @throws ApiError @@ -420,24 +422,41 @@ export const prefetchUseStructureServiceStructureData = ( queryClient: QueryClient, { dagId, + includeDownstream, + includeUpstream, root, }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; root?: string; }, ) => queryClient.prefetchQuery({ - queryKey: Common.UseStructureServiceStructureDataKeyFn({ dagId, root }), - queryFn: () => StructureService.structureData({ dagId, root }), + queryKey: Common.UseStructureServiceStructureDataKeyFn({ + dagId, + includeDownstream, + includeUpstream, + root, + }), + queryFn: () => + StructureService.structureData({ + dagId, + includeDownstream, + includeUpstream, + root, + }), }); /** * Grid Data * Return grid data. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream * @param data.baseDate * @param data.root - * @param data.runTypes + * @param data.runType * @param data.state * @param data.offset * @param data.limit @@ -449,18 +468,22 @@ export const prefetchUseGridServiceGridData = ( { baseDate, dagId, + includeDownstream, + includeUpstream, limit, offset, root, - runTypes, + runType, state, }: { baseDate?: string; dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; root?: string; - runTypes?: string[]; + runType?: string[]; state?: string[]; }, ) => @@ -468,20 +491,24 @@ export const prefetchUseGridServiceGridData = ( queryKey: Common.UseGridServiceGridDataKeyFn({ baseDate, dagId, + includeDownstream, + includeUpstream, limit, offset, root, - runTypes, + runType, state, }), queryFn: () => GridService.gridData({ baseDate, dagId, + includeDownstream, + includeUpstream, limit, offset, root, - runTypes, + runType, state, }), }); diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index c2c01f990e87a..216c6f75ce362 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -528,6 +528,8 @@ export const useDashboardServiceHistoricalMetrics = < * Get Structure Data. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream * @param data.root * @returns StructureDataResponse Successful Response * @throws ApiError @@ -539,9 +541,13 @@ export const useStructureServiceStructureData = < >( { dagId, + includeDownstream, + includeUpstream, root, }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; root?: string; }, queryKey?: TQueryKey, @@ -549,10 +555,16 @@ export const useStructureServiceStructureData = < ) => useQuery({ queryKey: Common.UseStructureServiceStructureDataKeyFn( - { dagId, root }, + { dagId, includeDownstream, includeUpstream, root }, queryKey, ), - queryFn: () => StructureService.structureData({ dagId, root }) as TData, + queryFn: () => + StructureService.structureData({ + dagId, + includeDownstream, + includeUpstream, + root, + }) as TData, ...options, }); /** @@ -560,9 +572,11 @@ export const useStructureServiceStructureData = < * Return grid data. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream * @param data.baseDate * @param data.root - * @param data.runTypes + * @param data.runType * @param data.state * @param data.offset * @param data.limit @@ -577,18 +591,22 @@ export const useGridServiceGridData = < { baseDate, dagId, + includeDownstream, + includeUpstream, limit, offset, root, - runTypes, + runType, state, }: { baseDate?: string; dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; root?: string; - runTypes?: string[]; + runType?: string[]; state?: string[]; }, queryKey?: TQueryKey, @@ -596,17 +614,29 @@ export const useGridServiceGridData = < ) => useQuery({ queryKey: Common.UseGridServiceGridDataKeyFn( - { baseDate, dagId, limit, offset, root, runTypes, state }, + { + baseDate, + dagId, + includeDownstream, + includeUpstream, + limit, + offset, + root, + runType, + state, + }, queryKey, ), queryFn: () => GridService.gridData({ baseDate, dagId, + includeDownstream, + includeUpstream, limit, offset, root, - runTypes, + runType, state, }) as TData, ...options, diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index aef66dbc57aec..bf65ad4d9b2b9 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -503,6 +503,8 @@ export const useDashboardServiceHistoricalMetricsSuspense = < * Get Structure Data. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream * @param data.root * @returns StructureDataResponse Successful Response * @throws ApiError @@ -514,9 +516,13 @@ export const useStructureServiceStructureDataSuspense = < >( { dagId, + includeDownstream, + includeUpstream, root, }: { dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; root?: string; }, queryKey?: TQueryKey, @@ -524,10 +530,16 @@ export const useStructureServiceStructureDataSuspense = < ) => useSuspenseQuery({ queryKey: Common.UseStructureServiceStructureDataKeyFn( - { dagId, root }, + { dagId, includeDownstream, includeUpstream, root }, queryKey, ), - queryFn: () => StructureService.structureData({ dagId, root }) as TData, + queryFn: () => + StructureService.structureData({ + dagId, + includeDownstream, + includeUpstream, + root, + }) as TData, ...options, }); /** @@ -535,9 +547,11 @@ export const useStructureServiceStructureDataSuspense = < * Return grid data. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream * @param data.baseDate * @param data.root - * @param data.runTypes + * @param data.runType * @param data.state * @param data.offset * @param data.limit @@ -552,18 +566,22 @@ export const useGridServiceGridDataSuspense = < { baseDate, dagId, + includeDownstream, + includeUpstream, limit, offset, root, - runTypes, + runType, state, }: { baseDate?: string; dagId: string; + includeDownstream?: boolean; + includeUpstream?: boolean; limit?: number; offset?: number; root?: string; - runTypes?: string[]; + runType?: string[]; state?: string[]; }, queryKey?: TQueryKey, @@ -571,17 +589,29 @@ export const useGridServiceGridDataSuspense = < ) => useSuspenseQuery({ queryKey: Common.UseGridServiceGridDataKeyFn( - { baseDate, dagId, limit, offset, root, runTypes, state }, + { + baseDate, + dagId, + includeDownstream, + includeUpstream, + limit, + offset, + root, + runType, + state, + }, queryKey, ), queryFn: () => GridService.gridData({ baseDate, dagId, + includeDownstream, + includeUpstream, limit, offset, root, - runTypes, + runType, state, }) as TData, ...options, diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 765f32c23d120..13598b47799b1 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -673,6 +673,8 @@ export class StructureService { * Get Structure Data. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream * @param data.root * @returns StructureDataResponse Successful Response * @throws ApiError @@ -685,6 +687,8 @@ export class StructureService { url: "/ui/structure/structure_data", query: { dag_id: data.dagId, + include_upstream: data.includeUpstream, + include_downstream: data.includeDownstream, root: data.root, }, errors: { @@ -701,9 +705,11 @@ export class GridService { * Return grid data. * @param data The data for the request. * @param data.dagId + * @param data.includeUpstream + * @param data.includeDownstream * @param data.baseDate * @param data.root - * @param data.runTypes + * @param data.runType * @param data.state * @param data.offset * @param data.limit @@ -720,9 +726,11 @@ export class GridService { dag_id: data.dagId, }, query: { + include_upstream: data.includeUpstream, + include_downstream: data.includeDownstream, base_date: data.baseDate, root: data.root, - run_types: data.runTypes, + run_type: data.runType, state: data.state, offset: data.offset, limit: data.limit, diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index bb3f597c1f94f..b2414639db825 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1460,6 +1460,8 @@ export type HistoricalMetricsResponse = HistoricalMetricDataResponse; export type StructureDataData = { dagId: string; + includeDownstream?: boolean | null; + includeUpstream?: boolean | null; root?: string | null; }; @@ -1468,10 +1470,12 @@ export type StructureDataResponse2 = StructureDataResponse; export type GridDataData = { baseDate?: string | null; dagId: string; + includeDownstream?: boolean | null; + includeUpstream?: boolean | null; limit?: number; offset?: number; root?: string | null; - runTypes?: Array; + runType?: Array; state?: Array; }; diff --git a/tests/api_fastapi/core_api/routes/ui/test_grid.py b/tests/api_fastapi/core_api/routes/ui/test_grid.py index c6b8dc241d7e1..2871c7a330142 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_grid.py +++ b/tests/api_fastapi/core_api/routes/ui/test_grid.py @@ -40,6 +40,9 @@ DAG_ID_2 = "test_dag_2" TASK_ID = "task" TASK_ID_2 = "task2" +SUB_TASK_ID = "subtask" +MAPPED_TASK_ID = "mapped_task" +TASK_GROUP_ID = "task_group" @pytest.fixture(autouse=True, scope="module") @@ -60,11 +63,11 @@ def setup(dag_maker, session=None): @task_group def mapped_task_group(arg1): - return MockOperator(task_id="subtask", arg1=arg1) + return MockOperator(task_id=SUB_TASK_ID, arg1=arg1) mapped_task_group.expand(arg1=["a", "b", "c"]) - with TaskGroup(group_id="task_group"): - MockOperator.partial(task_id="mapped_task").expand(arg1=["a", "b", "c", "d"]) + with TaskGroup(group_id=TASK_GROUP_ID): + MockOperator.partial(task_id=MAPPED_TASK_ID).expand(arg1=["a", "b", "c", "d"]) triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} logical_date = timezone.datetime(2024, 11, 30) @@ -187,7 +190,7 @@ def test_should_response_200(self, test_client): "up_for_retry": 0, "upstream_failed": 0, }, - "task_count": 1, + "task_count": 4, "task_id": "task_group", "try_number": 0, }, @@ -326,7 +329,7 @@ def test_should_response_200(self, test_client): "up_for_retry": 0, "upstream_failed": 0, }, - "task_count": 1, + "task_count": 4, "task_id": "task_group", "try_number": 0, }, @@ -411,6 +414,284 @@ def test_should_response_200(self, test_client): ], } + def test_should_response_200_include_upstream(self, test_client): + response = test_client.get( + f"/ui/grid/{DAG_ID}?root={SUB_TASK_ID}&include_upstream=true&include_downstream=false" + ) + assert response.status_code == 200 + print(response.json()) + assert response.json() == { + "dag_runs": [ + { + "data_interval_end": "2024-11-30T00:00:00Z", + "data_interval_start": "2024-11-29T00:00:00Z", + "end_date": "2024-12-01T00:00:00Z", + "note": None, + "queued_at": None, + "run_id": "run_1", + "run_type": "scheduled", + "start_date": "2016-01-01T00:00:00Z", + "state": "success", + "task_instances": [ + { + "end_date": None, + "note": None, + "overall_state": "success", + "queued_dttm": None, + "start_date": None, + "states": { + "deferred": 0, + "failed": 0, + "no_status": 0, + "queued": 0, + "removed": 0, + "restarting": 0, + "running": 0, + "scheduled": 0, + "skipped": 0, + "success": 3, + "up_for_reschedule": 0, + "up_for_retry": 0, + "upstream_failed": 0, + }, + "task_count": 3, + "task_id": "mapped_task_group", + "try_number": 0, + }, + { + "end_date": None, + "note": None, + "overall_state": "success", + "queued_dttm": None, + "start_date": None, + "states": { + "deferred": 0, + "failed": 0, + "no_status": 0, + "queued": 0, + "removed": 0, + "restarting": 0, + "running": 0, + "scheduled": 0, + "skipped": 0, + "success": 3, + "up_for_reschedule": 0, + "up_for_retry": 0, + "upstream_failed": 0, + }, + "task_count": 3, + "task_id": "mapped_task_group.subtask", + "try_number": 0, + }, + ], + "version_number": None, + }, + { + "data_interval_end": "2024-11-30T00:00:00Z", + "data_interval_start": "2024-11-29T00:00:00Z", + "end_date": "2024-12-01T00:00:00Z", + "note": None, + "queued_at": None, + "run_id": "run_2", + "run_type": "manual", + "start_date": "2016-01-01T00:00:00Z", + "state": "failed", + "task_instances": [ + { + "end_date": None, + "note": None, + "overall_state": None, + "queued_dttm": None, + "start_date": None, + "states": { + "deferred": 0, + "failed": 0, + "no_status": 3, + "queued": 0, + "removed": 0, + "restarting": 0, + "running": 0, + "scheduled": 0, + "skipped": 0, + "success": 0, + "up_for_reschedule": 0, + "up_for_retry": 0, + "upstream_failed": 0, + }, + "task_count": 3, + "task_id": "mapped_task_group", + "try_number": 0, + }, + { + "end_date": None, + "note": None, + "overall_state": None, + "queued_dttm": None, + "start_date": None, + "states": { + "deferred": 0, + "failed": 0, + "no_status": 3, + "queued": 0, + "removed": 0, + "restarting": 0, + "running": 0, + "scheduled": 0, + "skipped": 0, + "success": 0, + "up_for_reschedule": 0, + "up_for_retry": 0, + "upstream_failed": 0, + }, + "task_count": 3, + "task_id": "mapped_task_group.subtask", + "try_number": 0, + }, + ], + "version_number": None, + }, + ], + } + + def test_should_response_200_include_downstream(self, test_client): + response = test_client.get( + f"/ui/grid/{DAG_ID}?root={SUB_TASK_ID}&include_upstream=false&include_downstream=true" + ) + assert response.status_code == 200 + print(response.json()) + assert response.json() == { + "dag_runs": [ + { + "data_interval_end": "2024-11-30T00:00:00Z", + "data_interval_start": "2024-11-29T00:00:00Z", + "end_date": "2024-12-01T00:00:00Z", + "note": None, + "queued_at": None, + "run_id": "run_1", + "run_type": "scheduled", + "start_date": "2016-01-01T00:00:00Z", + "state": "success", + "task_instances": [ + { + "end_date": None, + "note": None, + "overall_state": "success", + "queued_dttm": None, + "start_date": None, + "states": { + "deferred": 0, + "failed": 0, + "no_status": 0, + "queued": 0, + "removed": 0, + "restarting": 0, + "running": 0, + "scheduled": 0, + "skipped": 0, + "success": 3, + "up_for_reschedule": 0, + "up_for_retry": 0, + "upstream_failed": 0, + }, + "task_count": 3, + "task_id": "mapped_task_group", + "try_number": 0, + }, + { + "end_date": None, + "note": None, + "overall_state": "success", + "queued_dttm": None, + "start_date": None, + "states": { + "deferred": 0, + "failed": 0, + "no_status": 0, + "queued": 0, + "removed": 0, + "restarting": 0, + "running": 0, + "scheduled": 0, + "skipped": 0, + "success": 3, + "up_for_reschedule": 0, + "up_for_retry": 0, + "upstream_failed": 0, + }, + "task_count": 3, + "task_id": "mapped_task_group.subtask", + "try_number": 0, + }, + ], + "version_number": None, + }, + { + "data_interval_end": "2024-11-30T00:00:00Z", + "data_interval_start": "2024-11-29T00:00:00Z", + "end_date": "2024-12-01T00:00:00Z", + "note": None, + "queued_at": None, + "run_id": "run_2", + "run_type": "manual", + "start_date": "2016-01-01T00:00:00Z", + "state": "failed", + "task_instances": [ + { + "end_date": None, + "note": None, + "overall_state": None, + "queued_dttm": None, + "start_date": None, + "states": { + "deferred": 0, + "failed": 0, + "no_status": 3, + "queued": 0, + "removed": 0, + "restarting": 0, + "running": 0, + "scheduled": 0, + "skipped": 0, + "success": 0, + "up_for_reschedule": 0, + "up_for_retry": 0, + "upstream_failed": 0, + }, + "task_count": 3, + "task_id": "mapped_task_group", + "try_number": 0, + }, + { + "end_date": None, + "note": None, + "overall_state": None, + "queued_dttm": None, + "start_date": None, + "states": { + "deferred": 0, + "failed": 0, + "no_status": 3, + "queued": 0, + "removed": 0, + "restarting": 0, + "running": 0, + "scheduled": 0, + "skipped": 0, + "success": 0, + "up_for_reschedule": 0, + "up_for_retry": 0, + "upstream_failed": 0, + }, + "task_count": 3, + "task_id": "mapped_task_group.subtask", + "try_number": 0, + }, + ], + "version_number": None, + }, + ], + } + def test_should_response_404(self, test_client): response = test_client.get("/ui/grid/invalid_dag") assert response.status_code == 404