Skip to content

Commit

Permalink
Test standard provider with Airflow 2.8 and 2.9
Browse files Browse the repository at this point in the history
The standard provider has now min version of Airflow = 2.8
since #43553, but we have not tested it for Airflow 2.8 and 2.9.
  • Loading branch information
potiuk committed Nov 3, 2024
1 parent ff6038b commit 630eb6f
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 41 deletions.
4 changes: 2 additions & 2 deletions dev/breeze/src/airflow_breeze/global_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,13 +574,13 @@ def get_airflow_extras():
{
"python-version": "3.9",
"airflow-version": "2.8.4",
"remove-providers": "cloudant fab edge standard",
"remove-providers": "cloudant fab edge",
"run-tests": "true",
},
{
"python-version": "3.9",
"airflow-version": "2.9.3",
"remove-providers": "cloudant edge standard",
"remove-providers": "cloudant edge",
"run-tests": "true",
},
{
Expand Down
58 changes: 41 additions & 17 deletions providers/src/airflow/providers/standard/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,16 @@
from airflow.settings import _ENABLE_AIP_44
from airflow.typing_compat import Literal
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_get_outlet_events, context_merge
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess, execute_in_subprocess_with_kwargs
from airflow.utils.session import create_session

log = logging.getLogger(__name__)

AIRFLOW_VERSION = Version(airflow_version)
AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0")
AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0")

if TYPE_CHECKING:
Expand Down Expand Up @@ -187,7 +188,15 @@ def __init__(
def execute(self, context: Context) -> Any:
context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
self.op_kwargs = self.determine_kwargs(context)
self._asset_events = context_get_outlet_events(context)

if AIRFLOW_V_3_0_PLUS:
from airflow.utils.context import context_get_outlet_events

self._asset_events = context_get_outlet_events(context)
elif AIRFLOW_V_2_10_PLUS:
from airflow.utils.context import context_get_outlet_events

self._dataset_events = context_get_outlet_events(context)

return_value = self.execute_callable()
if self.show_return_value_in_logs:
Expand All @@ -206,7 +215,15 @@ def execute_callable(self) -> Any:
:return: the return value of the call.
"""
runner = ExecutionCallableRunner(self.python_callable, self._asset_events, logger=self.log)
try:
from airflow.utils.operator_helpers import ExecutionCallableRunner

asset_events = self._asset_events if AIRFLOW_V_3_0_PLUS else self._dataset_events

runner = ExecutionCallableRunner(self.python_callable, asset_events, logger=self.log)
except ImportError:
# Handle Pre Airflow 3.10 case where ExecutionCallableRunner was not available
return self.python_callable(*self.op_args, **self.op_kwargs)
return runner.run(*self.op_args, **self.op_kwargs)


Expand Down Expand Up @@ -551,18 +568,25 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
env_vars.update(self.env_vars)

try:
execute_in_subprocess(
cmd=[
os.fspath(python_path),
os.fspath(script_path),
os.fspath(input_path),
os.fspath(output_path),
os.fspath(string_args_path),
os.fspath(termination_log_path),
os.fspath(airflow_context_path),
],
env=env_vars,
)
cmd: list[str] = [
os.fspath(python_path),
os.fspath(script_path),
os.fspath(input_path),
os.fspath(output_path),
os.fspath(string_args_path),
os.fspath(termination_log_path),
os.fspath(airflow_context_path),
]
if AIRFLOW_V_2_10_PLUS:
execute_in_subprocess(
cmd=cmd,
env=env_vars,
)
else:
execute_in_subprocess_with_kwargs(
cmd=cmd,
env=env_vars,
)
except subprocess.CalledProcessError as e:
if e.returncode in self.skip_on_exit_code:
raise AirflowSkipException(f"Process exited with code {e.returncode}. Skipping.")
Expand Down
23 changes: 21 additions & 2 deletions providers/src/airflow/providers/standard/sensors/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,27 @@
from __future__ import annotations

import datetime
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NoReturn, Sequence

from airflow.providers.standard.operators.python import AIRFLOW_V_3_0_PLUS
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.base import StartTriggerArgs

try:
from airflow.triggers.base import StartTriggerArgs
except ImportError:
# TODO: Remove this when min airflow version is 2.10.0 for standard provider
@dataclass
class StartTriggerArgs: # type: ignore[no-redef]
"""Arguments required for start task execution from triggerer."""

trigger_cls: str
next_method: str
trigger_kwargs: dict[str, Any] | None = None
next_kwargs: dict[str, Any] | None = None
timeout: datetime.timedelta | None = None


from airflow.triggers.temporal import DateTimeTrigger
from airflow.utils import timezone

Expand Down Expand Up @@ -125,7 +142,9 @@ def execute(self, context: Context) -> NoReturn:
trigger=DateTimeTrigger(
moment=timezone.parse(self.target_time),
end_from_trigger=self.end_from_trigger,
),
)
if AIRFLOW_V_3_0_PLUS
else DateTimeTrigger(moment=timezone.parse(self.target_time)),
)

def execute_complete(self, context: Context, event: Any = None) -> None:
Expand Down
23 changes: 21 additions & 2 deletions providers/src/airflow/providers/standard/sensors/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,27 @@
from __future__ import annotations

import datetime
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NoReturn

from airflow.providers.standard.operators.python import AIRFLOW_V_3_0_PLUS
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.base import StartTriggerArgs

try:
from airflow.triggers.base import StartTriggerArgs
except ImportError:
# TODO: Remove this when min airflow version is 2.10.0 for standard provider
@dataclass
class StartTriggerArgs: # type: ignore[no-redef]
"""Arguments required for start task execution from triggerer."""

trigger_cls: str
next_method: str
trigger_kwargs: dict[str, Any] | None = None
next_kwargs: dict[str, Any] | None = None
timeout: datetime.timedelta | None = None


from airflow.triggers.temporal import DateTimeTrigger
from airflow.utils import timezone

Expand Down Expand Up @@ -102,7 +119,9 @@ def __init__(

def execute(self, context: Context) -> NoReturn:
self.defer(
trigger=DateTimeTrigger(moment=self.target_datetime, end_from_trigger=self.end_from_trigger),
trigger=DateTimeTrigger(moment=self.target_datetime, end_from_trigger=self.end_from_trigger)
if AIRFLOW_V_3_0_PLUS
else DateTimeTrigger(moment=self.target_datetime),
method_name="execute_complete",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.utils import timezone

from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
from airflow.utils.context import Context

Expand Down Expand Up @@ -81,7 +83,10 @@ def execute(self, context: Context) -> bool | NoReturn:
# If the target datetime is in the past, return immediately
return True
try:
trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger)
if AIRFLOW_V_3_0_PLUS:
trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger)
else:
trigger = DateTimeTrigger(moment=target_dttm)
except (TypeError, ValueError) as e:
if self.soft_fail:
raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
Expand Down
5 changes: 1 addition & 4 deletions providers/tests/openlineage/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,11 @@
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType

from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS, BashOperator, PythonOperator
from tests_common.test_utils.compat import BashOperator, PythonOperator
from tests_common.test_utils.mock_operators import MockOperator

BASH_OPERATOR_PATH = "airflow.providers.standard.operators.bash"
PYTHON_OPERATOR_PATH = "airflow.providers.standard.operators.python"
if not AIRFLOW_V_2_10_PLUS:
BASH_OPERATOR_PATH = "airflow.operators.bash"
PYTHON_OPERATOR_PATH = "airflow.operators.python"


class CustomOperatorForTest(BashOperator):
Expand Down
40 changes: 27 additions & 13 deletions providers/tests/standard/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from airflow.models.taskinstance import TaskInstance, clear_task_instances, set_current_context
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import (
AIRFLOW_V_2_10_PLUS,
BranchExternalPythonOperator,
BranchPythonOperator,
BranchPythonVirtualenvOperator,
Expand Down Expand Up @@ -509,7 +510,7 @@ def f():
ti = self.create_ti(f)
with pytest.raises(
AirflowException,
match="'branch_task_ids' expected all task IDs are strings.",
match=r"'branch_task_ids'.*task.*",
):
ti.run()

Expand All @@ -518,7 +519,9 @@ def f():
return "some_task_id"

ti = self.create_ti(f)
with pytest.raises(AirflowException, match="Invalid tasks found: {'some_task_id'}"):
with pytest.raises(
AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*"
):
ti.run()

@pytest.mark.skip_if_database_isolation_mode # tests pure logic with run() method, can not run in isolation mode
Expand Down Expand Up @@ -903,9 +906,12 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance):
"ti",
"var", # Accessor for Variable; var->json and var->value.
"conn", # Accessor for Connection.
"inlet_events", # Accessor for inlet AssetEvent.
"outlet_events", # Accessor for outlet AssetEvent.
]
if AIRFLOW_V_2_10_PLUS:
intentionally_excluded_context_keys.extend(
# Accessors for inlet_events and outlet_events
["inlet_events", "outlet_events"]
)

ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None)
context = ti.get_template_context()
Expand Down Expand Up @@ -1627,21 +1633,25 @@ def f(a, b, c=False, d=False):
else:
raise RuntimeError

with pytest.raises(AirflowException, match=r"Invalid tasks found: {\((True|False), 'bool'\)}"):
with pytest.raises(
AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*"
):
self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True})

def test_return_false(self):
def f():
return False

with pytest.raises(AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}."):
with pytest.raises(
AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*"
):
self.run_as_task(f)

def test_context(self):
def f(templates_dict):
return templates_dict["ds"]

with pytest.raises(AirflowException, match="Invalid tasks found:"):
with pytest.raises(AirflowException, match="Invalid tasks found:|'branch_task_ids'.*task.*"):
self.run_as_task(f, templates_dict={"ds": "{{ ds }}"})

def test_environment_variables(self):
Expand All @@ -1652,7 +1662,7 @@ def f():

with pytest.raises(
AirflowException,
match=r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: {'ABCDE'}",
match=r"'branch_task_ids'.*task.*",
):
self.run_as_task(f, env_vars={"MY_ENV_VAR": "ABCDE"})

Expand All @@ -1666,7 +1676,7 @@ def f():

with pytest.raises(
AirflowException,
match=r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: {'QWERT'}",
match=r"'branch_task_ids'.*task.*",
):
self.run_as_task(f, inherit_env=True)

Expand All @@ -1691,7 +1701,7 @@ def f():

with pytest.raises(
AirflowException,
match=r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: {'EFGHI'}",
match=r"'branch_task_ids'.*task.*",
):
self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, inherit_env=True)

Expand All @@ -1706,7 +1716,9 @@ def test_with_no_caching(self):
def f():
return False

with pytest.raises(AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}."):
with pytest.raises(
AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*"
):
self.run_as_task(f, do_not_use_caching=True)

def test_with_dag_run(self):
Expand Down Expand Up @@ -1827,7 +1839,7 @@ def f():
ti = self.create_ti(f)
with pytest.raises(
AirflowException,
match="'branch_task_ids' expected all task IDs are strings.",
match=r"'branch_task_ids'.*task.*",
):
ti.run()

Expand All @@ -1836,7 +1848,9 @@ def f():
return "some_task_id"

ti = self.create_ti(f)
with pytest.raises(AirflowException, match="Invalid tasks found: {'some_task_id'}"):
with pytest.raises(
AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*"
):
ti.run()


Expand Down

0 comments on commit 630eb6f

Please sign in to comment.