Commit
Examples with taskflow and dynamic task mapping
tatiana committed Dec 6, 2024
1 parent 2045e48 commit 1879270
Showing 3 changed files with 164 additions and 171 deletions.
249 changes: 90 additions & 159 deletions dagfactory/dagbuilder.py
@@ -453,55 +453,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
if task_params.get("init_containers") is not None
else None
)
if utils.check_dict_key(task_params, "execution_timeout_secs"):
task_params["execution_timeout"]: timedelta = timedelta(seconds=task_params["execution_timeout_secs"])
del task_params["execution_timeout_secs"]

if utils.check_dict_key(task_params, "sla_secs"):
task_params["sla"]: timedelta = timedelta(seconds=task_params["sla_secs"])
del task_params["sla_secs"]

if utils.check_dict_key(task_params, "execution_delta_secs"):
task_params["execution_delta"]: timedelta = timedelta(seconds=task_params["execution_delta_secs"])
del task_params["execution_delta_secs"]

if utils.check_dict_key(task_params, "execution_date_fn_name") and utils.check_dict_key(
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]

# on_execute_callback is an Airflow 2.0 feature
if utils.check_dict_key(task_params, "on_execute_callback") and version.parse(
AIRFLOW_VERSION
) >= version.parse("2.0.0"):
task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"])

if utils.check_dict_key(task_params, "on_failure_callback"):
task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"])

if utils.check_dict_key(task_params, "on_success_callback"):
task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"])

if utils.check_dict_key(task_params, "on_retry_callback"):
task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"])

# use variables as arguments on operator
if utils.check_dict_key(task_params, "variables_as_arguments"):
variables: List[Dict[str, str]] = task_params.get("variables_as_arguments")
for variable in variables:
if Variable.get(variable["variable"], default_var=None) is not None:
task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None)
del task_params["variables_as_arguments"]

if (
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
) and version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")
DagBuilder.adjust_general_task_params(task_params)

expand_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {}
# expand available only in airflow >= 2.3.0
@@ -519,23 +471,6 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params):
task_params.update(partial_kwargs)

if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse(
"2.4.0"
):
if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key(
task_params["outlets"], "datasets"
):
file = task_params["outlets"]["file"]
datasets_filter = task_params["outlets"]["datasets"]
datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

del task_params["outlets"]["file"]
del task_params["outlets"]["datasets"]
else:
datasets_uri = task_params["outlets"]

task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]

task: Union[BaseOperator, MappedOperator] = (
operator_obj(**task_params)
if not expand_kwargs
@@ -906,6 +841,69 @@ def topological_sort_tasks(tasks_configs: dict[str, Any]) -> list[tuple(str, Any

return sorted_tasks

@staticmethod
def adjust_general_task_params(task_params: dict[str, Any]):
"""Adjusts the task params argument in place"""
if utils.check_dict_key(task_params, "execution_timeout_secs"):
task_params["execution_timeout"]: timedelta = timedelta(seconds=task_params["execution_timeout_secs"])
del task_params["execution_timeout_secs"]

if utils.check_dict_key(task_params, "sla_secs"):
task_params["sla"]: timedelta = timedelta(seconds=task_params["sla_secs"])
del task_params["sla_secs"]

if utils.check_dict_key(task_params, "execution_delta_secs"):
task_params["execution_delta"]: timedelta = timedelta(seconds=task_params["execution_delta_secs"])
del task_params["execution_delta_secs"]

if utils.check_dict_key(task_params, "execution_date_fn_name") and utils.check_dict_key(
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]

# on_execute_callback is an Airflow 2.0 feature
if utils.check_dict_key(task_params, "on_execute_callback") and version.parse(AIRFLOW_VERSION) >= version.parse(
"2.0.0"
):
task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"])

if utils.check_dict_key(task_params, "on_failure_callback"):
task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"])

if utils.check_dict_key(task_params, "on_success_callback"):
task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"])

if utils.check_dict_key(task_params, "on_retry_callback"):
task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"])

# use variables as arguments on operator
if utils.check_dict_key(task_params, "variables_as_arguments"):
variables: List[Dict[str, str]] = task_params.get("variables_as_arguments")
for variable in variables:
if Variable.get(variable["variable"], default_var=None) is not None:
task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None)
del task_params["variables_as_arguments"]

if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key(
task_params["outlets"], "datasets"
):
file = task_params["outlets"]["file"]
datasets_filter = task_params["outlets"]["datasets"]
datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

del task_params["outlets"]["file"]
del task_params["outlets"]["datasets"]
else:
datasets_uri = task_params["outlets"]

task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]

@staticmethod
def make_decorator(
decorator_import_path: str, task_params: Dict[str, Any], tasks_dict: dict[str, Any]
) -> BaseOperator:
@@ -942,110 +940,43 @@ def make_decorator(
decorator: Callable[..., BaseOperator] = import_string(decorator_import_path)
task_params.pop("decorator")

"""
Things that are handled when we create other tasks, we may want to handle some of this in the decorator as well:
if utils.check_dict_key(task_params, "execution_timeout_secs"):
task_params["execution_timeout"]: timedelta = timedelta(seconds=task_params["execution_timeout_secs"])
del task_params["execution_timeout_secs"]
if utils.check_dict_key(task_params, "sla_secs"):
task_params["sla"]: timedelta = timedelta(seconds=task_params["sla_secs"])
del task_params["sla_secs"]
if utils.check_dict_key(task_params, "execution_delta_secs"):
task_params["execution_delta"]: timedelta = timedelta(seconds=task_params["execution_delta_secs"])
del task_params["execution_delta_secs"]
DagBuilder.adjust_general_task_params(task_params)

if utils.check_dict_key(task_params, "execution_date_fn_name") and utils.check_dict_key(
task_params, "execution_date_fn_file"
):
task_params["execution_date_fn"]: Callable = utils.get_python_callable(
task_params["execution_date_fn_name"],
task_params["execution_date_fn_file"],
)
del task_params["execution_date_fn_name"]
del task_params["execution_date_fn_file"]
# on_execute_callback is an Airflow 2.0 feature
if utils.check_dict_key(task_params, "on_execute_callback") and version.parse(
AIRFLOW_VERSION
) >= version.parse("2.0.0"):
task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"])
if utils.check_dict_key(task_params, "on_failure_callback"):
task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"])
if utils.check_dict_key(task_params, "on_success_callback"):
task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"])
if utils.check_dict_key(task_params, "on_retry_callback"):
task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"])
# use variables as arguments on operator
if utils.check_dict_key(task_params, "variables_as_arguments"):
variables: List[Dict[str, str]] = task_params.get("variables_as_arguments")
for variable in variables:
if Variable.get(variable["variable"], default_var=None) is not None:
task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None)
del task_params["variables_as_arguments"]
if (
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
) and version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"):
raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0")
expand_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {}
# expand available only in airflow >= 2.3.0
if (
utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial")
) and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"):
# Getting expand and partial kwargs from task_params
(
task_params,
expand_kwargs,
partial_kwargs,
) = utils.get_expand_partial_kwargs(task_params)
# If there are partial_kwargs we should merge them with existing task_params
if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params):
task_params.update(partial_kwargs)
if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse(
"2.4.0"
):
if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key(
task_params["outlets"], "datasets"
):
file = task_params["outlets"]["file"]
datasets_filter = task_params["outlets"]["datasets"]
datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)
del task_params["outlets"]["file"]
del task_params["outlets"]["datasets"]
else:
datasets_uri = task_params["outlets"]
task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]
task: Union[BaseOperator, MappedOperator] = (
operator_obj(**task_params)
if not expand_kwargs
else operator_obj.partial(**task_params).expand(**expand_kwargs)
)
"""
callable_args_keys = inspect.getfullargspec(python_callable).args
callable_kwargs = {}
decorator_kwargs = dict(**task_params)
for arg_key, arg_value in task_params.items():
if arg_key in callable_args_keys:
decorator_kwargs.pop(arg_key)
if arg_value.startswith("+"):
upstream_task_name = arg_value.split("+")[0]
if isinstance(arg_value, str) and arg_value.startswith("+"):
upstream_task_name = arg_value.split("+")[-1]
callable_kwargs[arg_key] = tasks_dict[upstream_task_name]
else:
callable_kwargs[arg_key] = arg_value

return decorator(**decorator_kwargs)(**callable_kwargs)
expand_kwargs = decorator_kwargs.pop("expand", {})
partial_kwargs = decorator_kwargs.pop("partial", {})

if expand_kwargs and partial_kwargs:
if callable_kwargs:
raise DagFactoryConfigException(
"When using dynamic task mapping, all the task arguments should be defined in expand and partial."
)
DagBuilder.replace_kwargs_values_as_tasks(expand_kwargs, tasks_dict)
DagBuilder.replace_kwargs_values_as_tasks(partial_kwargs, tasks_dict)
return decorator(**decorator_kwargs).partial(**partial_kwargs).expand(**expand_kwargs)
elif expand_kwargs:
DagBuilder.replace_kwargs_values_as_tasks(expand_kwargs, tasks_dict)
return decorator(**decorator_kwargs).expand(**expand_kwargs)
else:
return decorator(**decorator_kwargs)(**callable_kwargs)

@staticmethod
def replace_kwargs_values_as_tasks(kwargs: dict[str, Any], tasks_dict: dict[str, Any]):
for key, value in kwargs.items():
if isinstance(value, str) and value.startswith("+"):
upstream_task_name = value.split("+")[-1]
kwargs[key] = tasks_dict[upstream_task_name]

@staticmethod
def set_callback(parameters: Union[dict, str], callback_type: str, has_name_and_file=False) -> Callable:
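
The net effect of this refactor: the `*_secs` conversions, callback imports, Variable lookups, and outlet handling previously inlined in `make_task` now live in the shared `adjust_general_task_params` helper, which `make_decorator` reuses, and `replace_kwargs_values_as_tasks` resolves the `+task_name` convention before `.partial().expand()` is applied. A minimal, self-contained sketch of those two transformations — simplified stand-ins, not the DagBuilder methods themselves (the real helpers also handle callbacks, `execution_date_fn`, and outlets):

from datetime import timedelta

def adjust_general_task_params(task_params: dict) -> None:
    # Convert the *_secs shorthand keys into timedelta values, in place.
    for secs_key, target_key in [
        ("execution_timeout_secs", "execution_timeout"),
        ("sla_secs", "sla"),
        ("execution_delta_secs", "execution_delta"),
    ]:
        if secs_key in task_params:
            task_params[target_key] = timedelta(seconds=task_params.pop(secs_key))

def replace_kwargs_values_as_tasks(kwargs: dict, tasks_dict: dict) -> None:
    # Resolve "+task_name" string values to previously built task objects.
    for key, value in kwargs.items():
        if isinstance(value, str) and value.startswith("+"):
            kwargs[key] = tasks_dict[value.split("+")[-1]]

params = {"execution_timeout_secs": 300, "sla_secs": 60, "retries": 2}
adjust_general_task_params(params)
assert params == {"execution_timeout": timedelta(seconds=300), "sla": timedelta(seconds=60), "retries": 2}

expand_kwargs = {"number": "+numbers_list"}
tasks_dict = {"numbers_list": object()}  # stand-in for a built task / XComArg
replace_kwargs_values_as_tasks(expand_kwargs, tasks_dict)
assert expand_kwargs["number"] is tasks_dict["numbers_list"]
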
48 changes: 39 additions & 9 deletions dev/dags/example_taskflow.yml
@@ -13,19 +13,49 @@ example_taskflow:
decorator: airflow.decorators.task
python_callable_name: build_numbers_list
python_callable_file: $CONFIG_ROOT_DIR/sample.py
another_numbers_list:
decorator: airflow.decorators.task
python_callable: sample.build_numbers_list
double_number_from_arg:
decorator: airflow.decorators.task
python_callable: sample.double
number: 2
double_number_from_task:
decorator: airflow.decorators.task
python_callable: sample.double
number: +some_number # this is a task previously defined
#dependencies: [some_number]
#double_list_dynamically:
# python_callable: !!python/name:sample.double"
# python_callable_file: $CONFIG_ROOT_DIR/sample.py
# map_index_template: "{{ my_param }}" # https://github.com/astronomer/dag-factory/issues/302
# expand: # should always receive kwargs
# value: first_list
# dependencies: [first_list]
number: +some_number # the + prefix tells DagFactory to resolve this value as the previously defined task `some_number`
double_number_with_dynamic_task_mapping_static:
decorator: airflow.decorators.task
python_callable: sample.double
expand:
number:
- 1
- 3
- 5
double_number_with_dynamic_task_mapping_taskflow:
decorator: airflow.decorators.task
python_callable: sample.double
expand:
number: +numbers_list # the + prefix tells DagFactory to resolve this value as the previously defined task `numbers_list`
multiply_with_multiple_parameters:
decorator: airflow.decorators.task
python_callable: sample.multiply
expand:
a: +numbers_list # the + prefix tells DagFactory to resolve this value as the previously defined task `numbers_list`
b: +another_numbers_list # the + prefix tells DagFactory to resolve this value as the previously defined task `another_numbers_list`
double_number_with_dynamic_task_and_partial:
decorator: airflow.decorators.task
python_callable: sample.double_with_label
expand:
number: +numbers_list # the + prefix tells DagFactory to resolve this value as the previously defined task `numbers_list`
partial:
label: True
dynamic_task_with_named_mapping:
decorator: airflow.decorators.task
python_callable: sample.extract_last_name
map_index_template: "{{ custom_mapping_key }}"
expand:
full_name:
- Lucy Black
- Vera Santos
- Marks Spencer
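
For orientation, the mapped entries above correspond roughly to the hand-written TaskFlow DAG below. This is a sketch, not DagFactory output: the task ids are shortened, the `dag()` arguments are placeholder assumptions, and `map_index_template` requires Airflow 2.9+.

from datetime import datetime

from airflow.decorators import dag, task

from sample import build_numbers_list, double, double_with_label, extract_last_name, multiply

@dag(schedule=None, start_date=datetime(2024, 1, 1))
def example_taskflow():
    numbers_list = task(build_numbers_list).override(task_id="numbers_list")()
    another_numbers_list = task(build_numbers_list).override(task_id="another_numbers_list")()

    # expand over a static list
    task(double).override(task_id="double_static").expand(number=[1, 3, 5])

    # expand over an upstream task's output (the +numbers_list entries)
    task(double).override(task_id="double_taskflow").expand(number=numbers_list)

    # several parameters expanded at once
    task(multiply).override(task_id="multiply").expand(a=numbers_list, b=another_numbers_list)

    # partial + expand: label is fixed, number is mapped
    task(double_with_label).override(task_id="double_partial").partial(label=True).expand(number=numbers_list)

    # named map indexes instead of 0, 1, 2 (Airflow 2.9+)
    task(extract_last_name).override(
        task_id="named_mapping", map_index_template="{{ custom_mapping_key }}"
    ).expand(full_name=["Lucy Black", "Vera Santos", "Marks Spencer"])

example_taskflow()
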
38 changes: 35 additions & 3 deletions dev/dags/sample.py
@@ -1,13 +1,45 @@
from random import randint

from airflow.operators.python import get_current_context


def build_numbers_list():
return [2, 4, 8]
return [2, 4, 6]


def some_number():
return randint(0, 100)


def double(number):
return 2 * number
def double(number: int):
result = 2 * number
print(result)
return result


def multiply(a: int, b: int) -> int:
result = a * b
print(result)
return result


# added_values = add.expand(x=first_list(), y=second_list())


def double_with_label(number: int, label: bool = False):
result = 2 * number
if not label:
print(result)
return result
else:
label_info = "even" if number % 2 == 0 else "odd"
print(f"{result} is {label_info}")
return result, label_info


def extract_last_name(full_name: str):
name, last_name = full_name.split(" ")
print(f"{name} {last_name}")
context = get_current_context()
context["custom_mapping_key"] = name
return last_name
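
These callables can be sanity-checked outside Airflow; a quick sketch, assuming `dev/dags` is on `sys.path` (and Airflow importable, since sample.py imports `get_current_context`):

from sample import double, double_with_label

assert double(3) == 6
assert double_with_label(3, label=True) == (6, "odd")

# extract_last_name calls get_current_context(), which only resolves inside a
# running Airflow task, so exercise the string handling with a local stub.
def extract_last_name_stub(full_name: str) -> str:
    first, last = full_name.split(" ")
    return last

names = ["Lucy Black", "Vera Santos", "Marks Spencer"]
assert [extract_last_name_stub(n) for n in names] == ["Black", "Santos", "Spencer"]
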
