Add support for nested task group @glazunov996 (#292)
Resolves Nested Subgroups issue #119

Original PR: #141

Co-authored-by: Kuris <[email protected]>

closes: #119

Example YAML DAG:
```yaml
default:
  default_args:
    owner: default_owner
    retries: 1
    retry_delay_sec: 300
    start_date: 2024-01-01
  default_view: tree
  max_active_runs: 1
  schedule_interval: 0 1 * * *
example_task_group:
  description: "this dag uses task groups"
  task_groups:
    task_group_1:
      tooltip: "this is a task group"
      dependencies: [task_1]
    task_group_2:
      tooltip: "this is a task group"
      parent_group_name: task_group_1
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 1"
    task_2:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 2"
      task_group_name: task_group_1
    task_4:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 4"
      task_group_name: task_group_2

```
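For completeness, a config like the one above is consumed the usual dag-factory way, from a loader script in the Airflow dags folder; a minimal sketch (the config file path is illustrative, not part of this change) looks like this:

```python
from airflow import DAG  # noqa: F401
import dagfactory

# Assumed location of the YAML config shown above (illustrative path).
config_file = "/usr/local/airflow/dags/example_task_group.yml"

dag_factory = dagfactory.DagFactory(config_file)
dag_factory.clean_dags(globals())     # drop DAGs generated from earlier versions of this config
dag_factory.generate_dags(globals())  # build the DAG, its (nested) task groups, and its tasks
```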
pankajastro authored Nov 20, 2024
1 parent 984e8cd commit 7626fa6
Showing 2 changed files with 100 additions and 43 deletions.
116 changes: 74 additions & 42 deletions dagfactory/dagbuilder.py
@@ -1,5 +1,7 @@
"""Module contains code for generating tasks and constructing a DAG"""

from __future__ import annotations

# pylint: disable=ungrouped-imports
import os
import re
@@ -100,7 +102,7 @@


# these are params only used in the DAG factory, not in the tasks
SYSTEM_PARAMS: List[str] = ["operator", "dependencies", "task_group_name"]
SYSTEM_PARAMS: List[str] = ["operator", "dependencies", "task_group_name", "parent_group_name"]


class DagBuilder:
@@ -548,52 +550,82 @@ def make_task_groups(task_groups: Dict[str, Any], dag: DAG) -> Dict[str, "TaskGroup"]:
task_groups_dict: Dict[str, "TaskGroup"] = {}
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
for task_group_name, task_group_conf in task_groups.items():
task_group_conf["group_id"] = task_group_name
task_group_conf["dag"] = dag
DagBuilder.make_nested_task_groups(
task_group_name, task_group_conf, task_groups_dict, task_groups, None, dag
)

if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0") and isinstance(
task_group_conf.get("default_args"), dict
):
# https://github.com/apache/airflow/pull/16557
if utils.check_dict_key(task_group_conf["default_args"], "on_success_callback"):
if isinstance(
task_group_conf["default_args"]["on_success_callback"],
str,
):
task_group_conf["default_args"]["on_success_callback"]: Callable = import_string(
task_group_conf["default_args"]["on_success_callback"]
)
return task_groups_dict

if utils.check_dict_key(task_group_conf["default_args"], "on_execute_callback"):
if isinstance(
task_group_conf["default_args"]["on_execute_callback"],
str,
):
task_group_conf["default_args"]["on_execute_callback"]: Callable = import_string(
task_group_conf["default_args"]["on_execute_callback"]
)
@staticmethod
def _init_task_group_callback_param(task_group_conf):
if not (
version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0")
and isinstance(task_group_conf.get("default_args"), dict)
):
return task_group_conf

if utils.check_dict_key(task_group_conf["default_args"], "on_failure_callback"):
if isinstance(
task_group_conf["default_args"]["on_failure_callback"],
str,
):
task_group_conf["default_args"]["on_failure_callback"]: Callable = import_string(
task_group_conf["default_args"]["on_failure_callback"]
)
default_args = task_group_conf["default_args"]
callback_keys = [
"on_success_callback",
"on_execute_callback",
"on_failure_callback",
"on_retry_callback",
]

if utils.check_dict_key(task_group_conf["default_args"], "on_retry_callback"):
if isinstance(
task_group_conf["default_args"]["on_retry_callback"],
str,
):
task_group_conf["default_args"]["on_retry_callback"]: Callable = import_string(
task_group_conf["default_args"]["on_retry_callback"]
)
for key in callback_keys:
if key in default_args and isinstance(default_args[key], str):
default_args[key]: Callable = import_string(default_args[key])

task_group = TaskGroup(**{k: v for k, v in task_group_conf.items() if k not in SYSTEM_PARAMS})
task_groups_dict[task_group.group_id] = task_group
return task_groups_dict
return task_group_conf

@staticmethod
def make_nested_task_groups(
task_group_name: str,
task_group_conf: Any,
task_groups_dict: Dict[str, "TaskGroup"],
task_groups: Dict[str, Any],
circularity_check_queue: List[str] | None,
dag: DAG,
):
"""Takes a DAG and task group configurations. Creates nested TaskGroup instances.
:param task_group_name: The name of the task group to be created
:param task_group_conf: Configuration details for the task group, which may include parent group information.
:param task_groups_dict: A dictionary where the created TaskGroup instances are stored, keyed by task group name.
:param task_groups: Task group configuration from the YAML configuration file.
:param circularity_check_queue: A list used to track the task groups being processed to detect circular dependencies.
:param dag: The DAG instance to which the task groups are added.
"""
if task_group_name in task_groups_dict:
return

if circularity_check_queue is None:
circularity_check_queue = []

if task_group_name in circularity_check_queue:
error_string = "Circular dependency detected:\n"
index = circularity_check_queue.index(task_group_name)
while index < len(circularity_check_queue):
error_string += f"{circularity_check_queue[index]} depends on {task_group_name}\n"
index += 1
raise Exception(error_string)

circularity_check_queue.append(task_group_name)

if task_group_conf.get("parent_group_name"):
parent_group_name = task_group_conf["parent_group_name"]
parent_group_conf = task_groups[parent_group_name]
DagBuilder.make_nested_task_groups(
parent_group_name, parent_group_conf, task_groups_dict, task_groups, circularity_check_queue, dag
)
task_group_conf["parent_group"] = task_groups_dict[parent_group_name]

task_group_conf["group_id"] = task_group_name
task_group_conf["dag"] = dag

task_group_conf = DagBuilder._init_task_group_callback_param(task_group_conf)

task_group = TaskGroup(**{k: v for k, v in task_group_conf.items() if k not in SYSTEM_PARAMS})
task_groups_dict[task_group_name] = task_group

@staticmethod
def set_dependencies(
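Side note on the new circularity check in make_nested_task_groups: a configuration in which task groups form a parent cycle, such as the deliberately invalid, hypothetical sketch below, is now rejected up front with the "Circular dependency detected" error:

```yaml
# Hypothetical, invalid config: each group declares the other as its parent.
example_circular_groups:
  task_groups:
    task_group_a:
      tooltip: "parent is task_group_b"
      parent_group_name: task_group_b
    task_group_b:
      tooltip: "parent is task_group_a"
      parent_group_name: task_group_a
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 1"
```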
27 changes: 26 additions & 1 deletion tests/test_dagbuilder.py
@@ -301,7 +301,7 @@
Sample, multi-line callback text.
""",
"channel": "#channel",
"username": "username"
"username": "username",
},
},
"description": "this is an example dag",
@@ -1033,3 +1033,28 @@ def test_make_task_outlets(mock_read_file, outlets, output):
operator = "airflow.operators.python_operator.PythonOperator"
actual = td.make_task(operator, task_params)
assert actual.outlets == [Dataset(uri) for uri in output]


@patch("dagfactory.dagbuilder.TaskGroup", new=MockTaskGroup)
def test_make_nested_task_groups():
task_group_dict = {
"task_group": {
"tooltip": "this is a task group",
},
"sub_task_group": {"tooltip": "this is a sub task group", "parent_group_name": "task_group"},
}
dag = "dag"
task_groups = dagbuilder.DagBuilder.make_task_groups(task_group_dict, dag)
expected = {
"task_group": MockTaskGroup(tooltip="this is a task group", group_id="task_group", dag=dag),
"sub_task_group": MockTaskGroup(tooltip="this is a sub task group", group_id="sub_task_group", dag=dag),
}

if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
assert task_groups == {}
else:
sub_task_group = task_groups["sub_task_group"].__dict__
assert sub_task_group["parent_group"]
del sub_task_group["parent_group"]
assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__
assert sub_task_group == expected["sub_task_group"].__dict__

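For readers less familiar with the YAML schema, the nesting exercised by this test and by the commit-message example corresponds roughly to the following hand-written Airflow DAG. This is a sketch only: it uses the modern BashOperator import rather than the legacy bash_operator path from the YAML and omits the default_args.

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator
from airflow.utils.task_group import TaskGroup

with DAG(
    dag_id="example_task_group",
    start_date=datetime(2024, 1, 1),
    schedule_interval="0 1 * * *",  # newer Airflow versions prefer the `schedule` argument
    max_active_runs=1,
) as dag:
    task_1 = BashOperator(task_id="task_1", bash_command="echo 1")

    with TaskGroup(group_id="task_group_1", tooltip="this is a task group") as task_group_1:
        task_2 = BashOperator(task_id="task_2", bash_command="echo 2")

        # Nested group: equivalent to parent_group_name: task_group_1 in the YAML.
        with TaskGroup(group_id="task_group_2", tooltip="this is a task group"):
            task_4 = BashOperator(task_id="task_4", bash_command="echo 4")

    # Equivalent to dependencies: [task_1] on task_group_1 in the YAML.
    task_1 >> task_group_1
```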