Added support for serverless job in Databricks operators #45188

Open
wants to merge 10 commits into main
@@ -56,6 +56,7 @@ Currently the named parameters that ``DatabricksCreateJobsOperator`` supports are
- ``max_concurrent_runs``
- ``git_source``
- ``access_control_list``
- ``databricks_environments``


Examples
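For illustration only (this snippet is not part of the diff), a job with a serverless task might reference an environment roughly as follows; the ``environments`` shape follows the Databricks Jobs API, and the package name, version, and environment key are hypothetical:

.. code-block:: python

    from airflow.providers.databricks.operators.databricks import DatabricksCreateJobsOperator

    create_serverless_job = DatabricksCreateJobsOperator(
        task_id="create_serverless_job",
        name="serverless-example-job",
        tasks=[
            {
                "task_key": "wheel_task",
                # Serverless: no cluster is specified; the task references an environment instead.
                "environment_key": "default",
                "python_wheel_task": {"package_name": "my_package", "entry_point": "main"},
            }
        ],
        databricks_environments=[
            {
                "environment_key": "default",
                "spec": {"client": "1", "dependencies": ["my_package==0.1.0"]},
            }
        ],
    )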
@@ -80,6 +80,7 @@ Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are
- ``libraries``
- ``run_name``
- ``timeout_seconds``
- ``databricks_environments``
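
The sketch below (not part of the diff) shows how the new parameter might be used for a one-off serverless run; the workspace path, environment key, and dependency are made up:

.. code-block:: python

    from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator

    submit_serverless_run = DatabricksSubmitRunOperator(
        task_id="submit_serverless_run",
        run_name="serverless-one-off-run",
        tasks=[
            {
                "task_key": "etl",
                # Serverless: the task points at an environment instead of a cluster.
                "environment_key": "default",
                "spark_python_task": {"python_file": "/Shared/scripts/etl.py"},
            }
        ],
        databricks_environments=[
            {
                "environment_key": "default",
                "spec": {"client": "1", "dependencies": ["pandas==2.2.0"]},
            }
        ],
    )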

.. code-block:: python

8 changes: 8 additions & 0 deletions docs/apache-airflow-providers-databricks/operators/task.rst
@@ -44,3 +44,11 @@ Running a SQL query in Databricks using DatabricksTaskOperator
:language: python
:start-after: [START howto_operator_databricks_task_sql]
:end-before: [END howto_operator_databricks_task_sql]

Running a Python file in Databricks using DatabricksTaskOperator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. exampleinclude:: /../../providers/tests/system/databricks/example_databricks.py
:language: python
:start-after: [START howto_operator_databricks_task_python]
:end-before: [END howto_operator_databricks_task_python]
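
The referenced example file is not shown in this diff; as a rough sketch (hypothetical path, environment key, and spec values), a serverless Python-file task could be configured like this:

.. code-block:: python

    from airflow.providers.databricks.operators.databricks import DatabricksTaskOperator

    run_python_file = DatabricksTaskOperator(
        task_id="run_python_file",
        databricks_conn_id="databricks_default",
        task_config={
            "spark_python_task": {"python_file": "/Shared/scripts/my_script.py"},
            # Assumed to be passed through as-is so the task can reference the environment below.
            "environment_key": "default",
        },
        # No new_cluster or existing_cluster_id: per this change, the task runs on serverless compute.
        databricks_environments=[
            {"environment_key": "default", "spec": {"client": "1"}}
        ],
    )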

61 changes: 40 additions & 21 deletions providers/src/airflow/providers/databricks/operators/databricks.py
@@ -292,6 +292,8 @@ class DatabricksCreateJobsOperator(BaseOperator):
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param databricks_environments: An optional list of task execution environment specifications
that can be referenced by serverless tasks of this job.

"""

@@ -323,6 +325,7 @@ def __init__(
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
databricks_retry_args: dict[Any, Any] | None = None,
databricks_environments: list[dict] | None = None,
**kwargs,
) -> None:
"""Create a new ``DatabricksCreateJobsOperator``."""
@@ -359,6 +362,8 @@ def __init__(
self.json["git_source"] = git_source
if access_control_list is not None:
self.json["access_control_list"] = access_control_list
if databricks_environments is not None:
self.json["environments"] = databricks_environments
if self.json:
self.json = normalise_json_content(self.json)

@@ -502,6 +507,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param git_source: Optional specification of a remote git repository from which
supported task types are retrieved.
:param deferrable: Run operator in the deferrable mode.
:param databricks_environments: An optional list of task execution environment specifications
that can be referenced by serverless tasks of this job.

.. seealso::
https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit
@@ -542,6 +549,7 @@ def __init__(
wait_for_termination: bool = True,
git_source: dict[str, str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
databricks_environments: list[dict] | None = None,
**kwargs,
) -> None:
"""Create a new ``DatabricksSubmitRunOperator``."""
@@ -586,6 +594,8 @@ def __init__(
self.json["access_control_list"] = access_control_list
if git_source is not None:
self.json["git_source"] = git_source
if databricks_environments is not None:
self.json["environments"] = databricks_environments

if "dbt_task" in self.json and "git_source" not in self.json:
raise AirflowException("git_source is required for dbt_task")
@@ -980,6 +990,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC):
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within
a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow.
:param databricks_environments: An optional list of task execution environment specifications
that can be referenced by serverless tasks of this job.
"""

def __init__(
@@ -996,6 +1008,7 @@ def __init__(
polling_period_seconds: int = 5,
wait_for_termination: bool = True,
workflow_run_metadata: dict[str, Any] | None = None,
databricks_environments: list[dict] | None = None,
**kwargs: Any,
):
self.caller = caller
@@ -1010,7 +1023,7 @@ def __init__(
self.polling_period_seconds = polling_period_seconds
self.wait_for_termination = wait_for_termination
self.workflow_run_metadata = workflow_run_metadata

self.databricks_environments = databricks_environments
self.databricks_run_id: int | None = None

super().__init__(**kwargs)
@@ -1087,7 +1100,9 @@ def _get_run_json(self) -> dict[str, Any]:
elif self.existing_cluster_id:
run_json["existing_cluster_id"] = self.existing_cluster_id
else:
raise ValueError("Must specify either existing_cluster_id or new_cluster.")
self.log.info("The task %s will be executed in serverless mode", run_json["run_name"])
if self.databricks_environments:
run_json["environments"] = self.databricks_environments
return run_json
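# Illustration (not part of this change): when neither new_cluster nor existing_cluster_id
# is set, the run JSON submitted for a serverless task might look roughly like
#   {"run_name": "...", "tasks": [{"task_key": "...", "environment_key": "default", ...}],
#    "environments": [{"environment_key": "default", "spec": {"client": "1"}}]}
# where the "default" environment key and spec values are hypothetical.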

def _launch_job(self, context: Context | None = None) -> int:
@@ -1372,27 +1387,29 @@ def _convert_to_databricks_workflow_task(

class DatabricksTaskOperator(DatabricksTaskBaseOperator):
"""
Runs a task on Databricks using an Airflow operator.

The DatabricksTaskOperator allows users to launch and monitor task job runs on Databricks as Airflow
tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, which
allows users to run their tasks on cheaper clusters that can be shared between tasks.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:DatabricksTaskOperator`

:param task_config: The configuration of the task to be run on Databricks.
:param databricks_conn_id: The name of the Airflow connection to use.
:param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
:param databricks_retry_delay: Number of seconds to wait between retries.
:param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
:param deferrable: Whether to run the operator in the deferrable mode.
:param existing_cluster_id: ID for existing cluster on which to run this task.
:param job_cluster_key: The key for the job cluster.
:param new_cluster: Specs for a new cluster on which this task will be run.
:param polling_period_seconds: Controls the rate at which we poll for the result of this notebook job run.
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param databricks_environments: An optional list of task execution environment specifications
that can be referenced by serverless tasks of this job.
"""

CALLER = "DatabricksTaskOperator"
@@ -1412,6 +1429,7 @@ def __init__(
polling_period_seconds: int = 5,
wait_for_termination: bool = True,
workflow_run_metadata: dict | None = None,
databricks_environments: list[dict] | None = None,
**kwargs,
):
self.task_config = task_config
@@ -1429,6 +1447,7 @@ def __init__(
polling_period_seconds=polling_period_seconds,
wait_for_termination=wait_for_termination,
workflow_run_metadata=workflow_run_metadata,
databricks_environments=databricks_environments,
**kwargs,
)

@@ -90,6 +90,8 @@ class _CreateDatabricksWorkflowOperator(BaseOperator):
will be passed to all notebooks in the workflow.
:param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be
populated after instantiation using the `add_task` method.
:param databricks_environments: An optional list of task execution environment specifications
that can be referenced by serverless tasks of this job.
"""

operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink())
@@ -106,6 +108,7 @@ def __init__(
max_concurrent_runs: int = 1,
notebook_params: dict | None = None,
tasks_to_convert: list[BaseOperator] | None = None,
databricks_environments: list[dict] | None = None,
**kwargs,
):
self.databricks_conn_id = databricks_conn_id
@@ -117,6 +120,7 @@ def __init__(
self.tasks_to_convert = tasks_to_convert or []
self.relevant_upstreams = [task_id]
self.workflow_run_metadata: WorkflowRunMetadata | None = None
self.databricks_environments = databricks_environments
super().__init__(task_id=task_id, **kwargs)

def _get_hook(self, caller: str) -> DatabricksHook:
@@ -156,6 +160,7 @@ def create_workflow_json(self, context: Context | None = None) -> dict[str, object]:
"format": "MULTI_TASK",
"job_clusters": self.job_clusters,
"max_concurrent_runs": self.max_concurrent_runs,
"environments": self.databricks_environments,
}
return merge(default_json, self.extra_job_params)

@@ -274,6 +279,8 @@ class DatabricksWorkflowTaskGroup(TaskGroup):
all python tasks in the workflow.
:param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters
will be passed to all spark submit tasks.
:param databricks_environments: An optional list of task execution environment specifications
that can be referenced by serverless tasks of this job.
"""

is_databricks = True
@@ -290,6 +297,7 @@ def __init__(
notebook_params: dict | None = None,
python_params: list | None = None,
spark_submit_params: list | None = None,
databricks_environments: list[dict] | None = None,
**kwargs,
):
self.databricks_conn_id = databricks_conn_id
@@ -302,6 +310,7 @@ def __init__(
self.notebook_params = notebook_params or {}
self.python_params = python_params or []
self.spark_submit_params = spark_submit_params or []
self.databricks_environments = databricks_environments or []
super().__init__(**kwargs)

def __exit__(
@@ -321,6 +330,7 @@ def __exit__(
job_clusters=self.job_clusters,
max_concurrent_runs=self.max_concurrent_runs,
notebook_params=self.notebook_params,
databricks_environments=self.databricks_environments,
)

for task in tasks: