diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index b58f9d7c..fda98e8e 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -1,5 +1,7 @@ """Module contains code for generating tasks and constructing a DAG""" +from __future__ import annotations + # pylint: disable=ungrouped-imports import os import re @@ -100,7 +102,7 @@ # these are params only used in the DAG factory, not in the tasks -SYSTEM_PARAMS: List[str] = ["operator", "dependencies", "task_group_name"] +SYSTEM_PARAMS: List[str] = ["operator", "dependencies", "task_group_name", "parent_group_name"] class DagBuilder: @@ -549,52 +551,82 @@ def make_task_groups(task_groups: Dict[str, Any], dag: DAG) -> Dict[str, "TaskGr task_groups_dict: Dict[str, "TaskGroup"] = {} if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"): for task_group_name, task_group_conf in task_groups.items(): - task_group_conf["group_id"] = task_group_name - task_group_conf["dag"] = dag + DagBuilder.make_nested_task_groups( + task_group_name, task_group_conf, task_groups_dict, task_groups, None, dag + ) - if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0") and isinstance( - task_group_conf.get("default_args"), dict - ): - # https://github.com/apache/airflow/pull/16557 - if utils.check_dict_key(task_group_conf["default_args"], "on_success_callback"): - if isinstance( - task_group_conf["default_args"]["on_success_callback"], - str, - ): - task_group_conf["default_args"]["on_success_callback"]: Callable = import_string( - task_group_conf["default_args"]["on_success_callback"] - ) + return task_groups_dict - if utils.check_dict_key(task_group_conf["default_args"], "on_execute_callback"): - if isinstance( - task_group_conf["default_args"]["on_execute_callback"], - str, - ): - task_group_conf["default_args"]["on_execute_callback"]: Callable = import_string( - task_group_conf["default_args"]["on_execute_callback"] - ) + @staticmethod + def _init_task_group_callback_param(task_group_conf): + if not ( + version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0") + and isinstance(task_group_conf.get("default_args"), dict) + ): + return task_group_conf - if utils.check_dict_key(task_group_conf["default_args"], "on_failure_callback"): - if isinstance( - task_group_conf["default_args"]["on_failure_callback"], - str, - ): - task_group_conf["default_args"]["on_failure_callback"]: Callable = import_string( - task_group_conf["default_args"]["on_failure_callback"] - ) + default_args = task_group_conf["default_args"] + callback_keys = [ + "on_success_callback", + "on_execute_callback", + "on_failure_callback", + "on_retry_callback", + ] - if utils.check_dict_key(task_group_conf["default_args"], "on_retry_callback"): - if isinstance( - task_group_conf["default_args"]["on_retry_callback"], - str, - ): - task_group_conf["default_args"]["on_retry_callback"]: Callable = import_string( - task_group_conf["default_args"]["on_retry_callback"] - ) + for key in callback_keys: + if key in default_args and isinstance(default_args[key], str): + default_args[key]: Callable = import_string(default_args[key]) - task_group = TaskGroup(**{k: v for k, v in task_group_conf.items() if k not in SYSTEM_PARAMS}) - task_groups_dict[task_group.group_id] = task_group - return task_groups_dict + return task_group_conf + + @staticmethod + def make_nested_task_groups( + task_group_name: str, + task_group_conf: Any, + task_groups_dict: Dict[str, "TaskGroup"], + task_groups: Dict[str, Any], + circularity_check_queue: List[str] | None, + dag: DAG, + ): + """Takes a DAG and task group configurations. Creates nested TaskGroup instances. + :param task_group_name: The name of the task group to be created + :param task_group_conf: Configuration details for the task group, which may include parent group information. + :param task_groups_dict: A dictionary where the created TaskGroup instances are stored, keyed by task group name. + :param task_groups: Task group configuration from the YAML configuration file. + :param circularity_check_queue: A list used to track the task groups being processed to detect circular dependencies. + :param dag: DAG instance that task groups to be added. + """ + if task_group_name in task_groups_dict: + return + + if circularity_check_queue is None: + circularity_check_queue = [] + + if task_group_name in circularity_check_queue: + error_string = "Circular dependency detected:\n" + index = circularity_check_queue.index(task_group_name) + while index < len(circularity_check_queue): + error_string += f"{circularity_check_queue[index]} depends on {task_group_name}\n" + index += 1 + raise Exception(error_string) + + circularity_check_queue.append(task_group_name) + + if task_group_conf.get("parent_group_name"): + parent_group_name = task_group_conf["parent_group_name"] + parent_group_conf = task_groups[parent_group_name] + DagBuilder.make_nested_task_groups( + parent_group_name, parent_group_conf, task_groups_dict, task_groups, circularity_check_queue, dag + ) + task_group_conf["parent_group"] = task_groups_dict[parent_group_name] + + task_group_conf["group_id"] = task_group_name + task_group_conf["dag"] = dag + + task_group_conf = DagBuilder._init_task_group_callback_param(task_group_conf) + + task_group = TaskGroup(**{k: v for k, v in task_group_conf.items() if k not in SYSTEM_PARAMS}) + task_groups_dict[task_group_name] = task_group @staticmethod def set_dependencies( diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index 7081e360..1b706a2f 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -301,7 +301,7 @@ Sample, multi-line callback text. """, "channel": "#channel", - "username": "username" + "username": "username", }, }, "description": "this is an example dag", @@ -1033,3 +1033,28 @@ def test_make_task_outlets(mock_read_file, outlets, output): operator = "airflow.operators.python_operator.PythonOperator" actual = td.make_task(operator, task_params) assert actual.outlets == [Dataset(uri) for uri in output] + + +@patch("dagfactory.dagbuilder.TaskGroup", new=MockTaskGroup) +def test_make_nested_task_groups(): + task_group_dict = { + "task_group": { + "tooltip": "this is a task group", + }, + "sub_task_group": {"tooltip": "this is a sub task group", "parent_group_name": "task_group"}, + } + dag = "dag" + task_groups = dagbuilder.DagBuilder.make_task_groups(task_group_dict, dag) + expected = { + "task_group": MockTaskGroup(tooltip="this is a task group", group_id="task_group", dag=dag), + "sub_task_group": MockTaskGroup(tooltip="this is a sub task group", group_id="sub_task_group", dag=dag), + } + + if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"): + assert task_groups == {} + else: + sub_task_group = task_groups["sub_task_group"].__dict__ + assert sub_task_group["parent_group"] + del sub_task_group["parent_group"] + assert task_groups["task_group"].__dict__ == expected["task_group"].__dict__ + assert sub_task_group == expected["sub_task_group"].__dict__