From a0c7fe735b14155c061f3975134fdb076a271411 Mon Sep 17 00:00:00 2001 From: hari-selvarajan_data Date: Sat, 21 Dec 2024 23:12:46 +0000 Subject: [PATCH 1/9] test first commit --- .../src/airflow/providers/databricks/operators/databricks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py b/providers/src/airflow/providers/databricks/operators/databricks.py index 1b8d45fa479a3..e62caeb40ab57 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/src/airflow/providers/databricks/operators/databricks.py @@ -1254,6 +1254,7 @@ class DatabricksNotebookOperator(DatabricksTaskBaseOperator): :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow. + :param is_serverless: flag to run job as a serverless job default to false. """ template_fields = ( From 29b485e0b90941bccc755d5fd8d33974d3387688 Mon Sep 17 00:00:00 2001 From: hari-selvarajan_data Date: Sun, 22 Dec 2024 17:46:35 +0000 Subject: [PATCH 2/9] added environment param and replaced error message when both job_cluster_key and existing_cluster_id are not present, with info that the task will run in serverless mode --- .../databricks/operators/databricks.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py b/providers/src/airflow/providers/databricks/operators/databricks.py index e62caeb40ab57..6fe22cf139d84 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/src/airflow/providers/databricks/operators/databricks.py @@ -292,6 +292,8 @@ class DatabricksCreateJobsOperator(BaseOperator): :param databricks_retry_delay: Number of seconds to wait between retries (it might be a floating point number). :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param databricks_environments: An optional list of task execution environment specifications + that can be referenced by serverless tasks of this job. """ @@ -323,6 +325,7 @@ def __init__( databricks_retry_limit: int = 3, databricks_retry_delay: int = 1, databricks_retry_args: dict[Any, Any] | None = None, + databricks_environments: list[dict] | None = None, **kwargs, ) -> None: """Create a new ``DatabricksCreateJobsOperator``.""" @@ -359,6 +362,8 @@ def __init__( self.json["git_source"] = git_source if access_control_list is not None: self.json["access_control_list"] = access_control_list + if databricks_environments is not None: + self.json["environments"] = databricks_environments if self.json: self.json = normalise_json_content(self.json) @@ -502,6 +507,8 @@ class DatabricksSubmitRunOperator(BaseOperator): :param git_source: Optional specification of a remote git repository from which supported task types are retrieved. :param deferrable: Run operator in the deferrable mode. + :param databricks_environments: An optional list of task execution environment specifications + that can be referenced by serverless tasks of this job. .. 
seealso:: https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit @@ -542,6 +549,7 @@ def __init__( wait_for_termination: bool = True, git_source: dict[str, str] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + databricks_environments: list[dict] | None = None, **kwargs, ) -> None: """Create a new ``DatabricksSubmitRunOperator``.""" @@ -586,6 +594,8 @@ def __init__( self.json["access_control_list"] = access_control_list if git_source is not None: self.json["git_source"] = git_source + if databricks_environments is not None: + self.json["environments"] = databricks_environments if "dbt_task" in self.json and "git_source" not in self.json: raise AirflowException("git_source is required for dbt_task") @@ -980,6 +990,8 @@ class DatabricksTaskBaseOperator(BaseOperator, ABC): :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow. + :param databricks_environments: An optional list of task execution environment specifications + that can be referenced by serverless tasks of this job. """ def __init__( @@ -996,6 +1008,7 @@ def __init__( polling_period_seconds: int = 5, wait_for_termination: bool = True, workflow_run_metadata: dict[str, Any] | None = None, + databricks_environments: list[dict] | None = None, **kwargs: Any, ): self.caller = caller @@ -1010,7 +1023,7 @@ def __init__( self.polling_period_seconds = polling_period_seconds self.wait_for_termination = wait_for_termination self.workflow_run_metadata = workflow_run_metadata - + self.databricks_environments = databricks_environments self.databricks_run_id: int | None = None super().__init__(**kwargs) @@ -1087,7 +1100,9 @@ def _get_run_json(self) -> dict[str, Any]: elif self.existing_cluster_id: run_json["existing_cluster_id"] = self.existing_cluster_id else: - raise ValueError("Must specify either existing_cluster_id or new_cluster.") + log.info("The task %s will be executed in serverless mode", run_json["run_name"]) + if self.databricks_environments: + run_json["environments"] = self.databricks_environments return run_json def _launch_job(self, context: Context | None = None) -> int: @@ -1147,9 +1162,7 @@ def _convert_to_databricks_workflow_task( } if self.existing_cluster_id and self.job_cluster_key: - raise ValueError( - "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task." - ) + log.info("The task %s will be executed in serverless mode", result["task_key"]) if self.existing_cluster_id: result["existing_cluster_id"] = self.existing_cluster_id elif self.job_cluster_key: @@ -1254,7 +1267,6 @@ class DatabricksNotebookOperator(DatabricksTaskBaseOperator): :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. :param workflow_run_metadata: Metadata for the workflow run. This is used when the operator is used within a workflow. It is expected to be a dictionary containing the run_id and conn_id for the workflow. - :param is_serverless: flag to run job as a serverless job default to false. 
""" template_fields = ( @@ -1286,6 +1298,7 @@ def __init__( self.source = source self.notebook_packages = notebook_packages or [] self.notebook_params = notebook_params or {} + self.environment_key = environment_key or "" super().__init__( caller=self.CALLER, @@ -1394,6 +1407,8 @@ class DatabricksTaskOperator(DatabricksTaskBaseOperator): :param new_cluster: Specs for a new cluster on which this task will be run. :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run. :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. + :param databricks_environments: An optional list of task execution environment specifications + that can be referenced by serverless tasks of this job. """ CALLER = "DatabricksTaskOperator" @@ -1413,6 +1428,7 @@ def __init__( polling_period_seconds: int = 5, wait_for_termination: bool = True, workflow_run_metadata: dict | None = None, + databricks_environments: list[dict] | None = None, **kwargs, ): self.task_config = task_config @@ -1430,6 +1446,7 @@ def __init__( polling_period_seconds=polling_period_seconds, wait_for_termination=wait_for_termination, workflow_run_metadata=workflow_run_metadata, + databricks_environments=databricks_environments **kwargs, ) From fa7cabe7d6f141c7b34806868d9a850bea347892 Mon Sep 17 00:00:00 2001 From: hari-selvarajan_data Date: Sun, 22 Dec 2024 18:33:49 +0000 Subject: [PATCH 3/9] added databricks_environments in databricks_workflow.py --- .../databricks/operators/databricks_workflow.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py index 6df8e2d025cea..7f2096f6a35b3 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -90,6 +90,8 @@ class _CreateDatabricksWorkflowOperator(BaseOperator): will be passed to all notebooks in the workflow. :param tasks_to_convert: A list of tasks to convert to a Databricks workflow. This list can also be populated after instantiation using the `add_task` method. + :param databricks_environments: An optional list of task execution environment specifications + that can be referenced by serverless tasks of this job. """ operator_extra_links = (WorkflowJobRunLink(), WorkflowJobRepairAllFailedLink()) @@ -106,6 +108,7 @@ def __init__( max_concurrent_runs: int = 1, notebook_params: dict | None = None, tasks_to_convert: list[BaseOperator] | None = None, + databricks_environments: list[dict] | None = None, **kwargs, ): self.databricks_conn_id = databricks_conn_id @@ -117,6 +120,7 @@ def __init__( self.tasks_to_convert = tasks_to_convert or [] self.relevant_upstreams = [task_id] self.workflow_run_metadata: WorkflowRunMetadata | None = None + self.databricks_environments = databricks_environments super().__init__(task_id=task_id, **kwargs) def _get_hook(self, caller: str) -> DatabricksHook: @@ -156,6 +160,7 @@ def create_workflow_json(self, context: Context | None = None) -> dict[str, obje "format": "MULTI_TASK", "job_clusters": self.job_clusters, "max_concurrent_runs": self.max_concurrent_runs, + "environments": self.databricks_environments, } return merge(default_json, self.extra_job_params) @@ -274,6 +279,8 @@ class DatabricksWorkflowTaskGroup(TaskGroup): all python tasks in the workflow. 
:param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters will be passed to all spark submit tasks. + :param databricks_environments: An optional list of task execution environment specifications + that can be referenced by serverless tasks of this job. """ is_databricks = True @@ -290,6 +297,7 @@ def __init__( notebook_params: dict | None = None, python_params: list | None = None, spark_submit_params: list | None = None, + databricks_environments: list[dict] | None = None, **kwargs, ): self.databricks_conn_id = databricks_conn_id @@ -302,6 +310,7 @@ def __init__( self.notebook_params = notebook_params or {} self.python_params = python_params or [] self.spark_submit_params = spark_submit_params or [] + self.databricks_environments = databricks_environments or [] super().__init__(**kwargs) def __exit__( @@ -321,6 +330,7 @@ def __exit__( job_clusters=self.job_clusters, max_concurrent_runs=self.max_concurrent_runs, notebook_params=self.notebook_params, + databricks_environments=self.databricks_environments, ) for task in tasks: From 80c1804e370be4f35a003a9d17fd1844e0be2ccf Mon Sep 17 00:00:00 2001 From: hari-selvarajan_data Date: Sun, 22 Dec 2024 19:36:31 +0000 Subject: [PATCH 4/9] added databricks_environments in databricks_workflow.py --- .../src/airflow/providers/databricks/operators/databricks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py b/providers/src/airflow/providers/databricks/operators/databricks.py index 6fe22cf139d84..69d966b560ef7 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/src/airflow/providers/databricks/operators/databricks.py @@ -1408,7 +1408,7 @@ class DatabricksTaskOperator(DatabricksTaskBaseOperator): :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run. :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. :param databricks_environments: An optional list of task execution environment specifications - that can be referenced by serverless tasks of this job. 
+that can be referenced by serverless tasks of this job """ CALLER = "DatabricksTaskOperator" @@ -1446,7 +1446,7 @@ def __init__( polling_period_seconds=polling_period_seconds, wait_for_termination=wait_for_termination, workflow_run_metadata=workflow_run_metadata, - databricks_environments=databricks_environments + databricks_environments=databricks_environments, **kwargs, ) From 95f1818bb2949060b0050d43d347b4f4f84686b8 Mon Sep 17 00:00:00 2001 From: hari-selvarajan_data Date: Sun, 22 Dec 2024 19:43:46 +0000 Subject: [PATCH 5/9] ruff format --- .../databricks/operators/databricks.py | 47 +++++++++---------- .../operators/databricks_workflow.py | 2 +- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py b/providers/src/airflow/providers/databricks/operators/databricks.py index 69d966b560ef7..08a7e7541b74a 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/src/airflow/providers/databricks/operators/databricks.py @@ -1100,7 +1100,7 @@ def _get_run_json(self) -> dict[str, Any]: elif self.existing_cluster_id: run_json["existing_cluster_id"] = self.existing_cluster_id else: - log.info("The task %s will be executed in serverless mode", run_json["run_name"]) + self.log.info("The task %s will be executed in serverless mode", run_json["run_name"]) if self.databricks_environments: run_json["environments"] = self.databricks_environments return run_json @@ -1162,7 +1162,7 @@ def _convert_to_databricks_workflow_task( } if self.existing_cluster_id and self.job_cluster_key: - log.info("The task %s will be executed in serverless mode", result["task_key"]) + self.log.info("The task %s will be executed in serverless mode", result["task_key"]) if self.existing_cluster_id: result["existing_cluster_id"] = self.existing_cluster_id elif self.job_cluster_key: @@ -1298,7 +1298,6 @@ def __init__( self.source = source self.notebook_packages = notebook_packages or [] self.notebook_params = notebook_params or {} - self.environment_key = environment_key or "" super().__init__( caller=self.CALLER, @@ -1386,29 +1385,29 @@ def _convert_to_databricks_workflow_task( class DatabricksTaskOperator(DatabricksTaskBaseOperator): """ - Runs a task on Databricks using an Airflow operator. + Runs a task on Databricks using an Airflow operator. - The DatabricksTaskOperator allows users to launch and monitor task job runs on Databricks as Airflow - tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, which - allows users to run their tasks on cheaper clusters that can be shared between tasks. + The DatabricksTaskOperator allows users to launch and monitor task job runs on Databricks as Airflow + tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, which + allows users to run their tasks on cheaper clusters that can be shared between tasks. - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:DatabricksTaskOperator` - - :param task_config: The configuration of the task to be run on Databricks. - :param databricks_conn_id: The name of the Airflow connection to use. - :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. - :param databricks_retry_delay: Number of seconds to wait between retries. - :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. 
- :param deferrable: Whether to run the operator in the deferrable mode. - :param existing_cluster_id: ID for existing cluster on which to run this task. - :param job_cluster_key: The key for the job cluster. - :param new_cluster: Specs for a new cluster on which this task will be run. - :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run. - :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. - :param databricks_environments: An optional list of task execution environment specifications -that can be referenced by serverless tasks of this job + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DatabricksTaskOperator` + + :param task_config: The configuration of the task to be run on Databricks. + :param databricks_conn_id: The name of the Airflow connection to use. + :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param databricks_retry_delay: Number of seconds to wait between retries. + :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable. + :param deferrable: Whether to run the operator in the deferrable mode. + :param existing_cluster_id: ID for existing cluster on which to run this task. + :param job_cluster_key: The key for the job cluster. + :param new_cluster: Specs for a new cluster on which this task will be run. + :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run. + :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. + :param databricks_environments: An optional list of task execution environment specifications + that can be referenced by serverless tasks of this job """ CALLER = "DatabricksTaskOperator" diff --git a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py index 7f2096f6a35b3..a350c23b6d1fe 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -335,7 +335,7 @@ def __exit__( for task in tasks: if not ( - hasattr(task, "_convert_to_databricks_workflow_task") + hasattr(task, "_convert_to_databricks_workflow_task.") and callable(task._convert_to_databricks_workflow_task) ): raise AirflowException( From 482bfb5be60c96b783ab6dc144920537c45ac9b8 Mon Sep 17 00:00:00 2001 From: hari-selvarajan_data Date: Sun, 22 Dec 2024 19:53:08 +0000 Subject: [PATCH 6/9] fix mypy error --- .../providers/databricks/operators/databricks_workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py index a350c23b6d1fe..7f2096f6a35b3 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -335,7 +335,7 @@ def __exit__( for task in tasks: if not ( - hasattr(task, "_convert_to_databricks_workflow_task.") + hasattr(task, "_convert_to_databricks_workflow_task") and callable(task._convert_to_databricks_workflow_task) ): raise AirflowException( From 2a94bd361b84af8886cd93e72f1685f93cebc77c Mon Sep 17 00:00:00 2001 From: 
hari-selvarajan_data Date: Mon, 23 Dec 2024 11:44:40 +0000 Subject: [PATCH 7/9] adding unit test cases --- .../databricks/operators/test_databricks.py | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/providers/tests/databricks/operators/test_databricks.py b/providers/tests/databricks/operators/test_databricks.py index da3c697360ffe..159e445710977 100644 --- a/providers/tests/databricks/operators/test_databricks.py +++ b/providers/tests/databricks/operators/test_databricks.py @@ -264,6 +264,15 @@ "permission_level": "CAN_MANAGE", } ] +ENVIRONMENTS = [ + { + "environment_key": "default_environment", + "spec": { + "client": "1", + "dependencies": ["library1"], + }, + } +] def mock_dict(d: dict): @@ -306,6 +315,7 @@ def test_init_with_named_parameters(self): max_concurrent_runs=MAX_CONCURRENT_RUNS, git_source=GIT_SOURCE, access_control_list=ACCESS_CONTROL_LIST, + databricks_environments=ENVIRONMENTS, ) expected = utils.normalise_json_content( { @@ -320,6 +330,7 @@ def test_init_with_named_parameters(self): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "environments": ENVIRONMENTS, } ) @@ -341,6 +352,7 @@ def test_init_with_json(self): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "environments": ENVIRONMENTS, } op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) @@ -357,6 +369,7 @@ def test_init_with_json(self): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "environments": ENVIRONMENTS, } ) @@ -379,6 +392,7 @@ def test_init_with_merging(self): override_max_concurrent_runs = 0 override_git_source = {} override_access_control_list = [] + override_environments = [] json = { "name": JOB_NAME, "tags": TAGS, @@ -391,6 +405,7 @@ def test_init_with_merging(self): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "environments": ENVIRONMENTS, } op = DatabricksCreateJobsOperator( @@ -407,6 +422,7 @@ def test_init_with_merging(self): max_concurrent_runs=override_max_concurrent_runs, git_source=override_git_source, access_control_list=override_access_control_list, + databricks_environments=override_environments, ) expected = utils.normalise_json_content( @@ -422,6 +438,7 @@ def test_init_with_merging(self): "max_concurrent_runs": override_max_concurrent_runs, "git_source": override_git_source, "access_control_list": override_access_control_list, + "environments": override_environments, } ) @@ -465,6 +482,7 @@ def test_exec_create(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "environments": ENVIRONMENTS, } op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) db_mock = db_mock_class.return_value @@ -489,6 +507,7 @@ def test_exec_create(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "environments": ENVIRONMENTS, } ) db_mock_class.assert_called_once_with( @@ -521,6 +540,7 @@ def test_exec_reset(self, db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "environments": ENVIRONMENTS, } op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) db_mock = db_mock_class.return_value @@ -543,6 +563,7 @@ def test_exec_reset(self, 
db_mock_class): "max_concurrent_runs": MAX_CONCURRENT_RUNS, "git_source": GIT_SOURCE, "access_control_list": ACCESS_CONTROL_LIST, + "environments": ENVIRONMENTS, } ) db_mock_class.assert_called_once_with( @@ -653,6 +674,79 @@ def test_init_with_spark_python_task_named_parameters(self): assert expected == utils.normalise_json_content(op.json) + def test_init_with_serverless_spark_python_task_named_parameters(self): + """ + Test the initializer with the named parameters. + """ + python_tasks = [ + { + "task_key": "pythong_task_1", + "new_cluster": { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "spark_conf": { + "spark.speculation": True, + }, + "aws_attributes": { + "availability": "SPOT", + "zone_id": "us-west-2a", + }, + "autoscale": { + "min_workers": 2, + "max_workers": 16, + }, + }, + "notebook_task": { + "notebook_path": "/Users/user.name@databricks.com/Match", + "source": "WORKSPACE", + "base_parameters": { + "name": "John Doe", + "age": "35", + }, + }, + "timeout_seconds": 86400, + "max_retries": 3, + "min_retry_interval_millis": 2000, + "retry_on_timeout": False, + "environment_key": "default_environment", + }, + ] + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": python_tasks, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + } + op = DatabricksSubmitRunOperator( + task_id=TASK_ID, + json=json, + databricks_environments=ENVIRONMENTS, + ) + expected = utils.normalise_json_content( + { + "name": JOB_NAME, + "tags": TAGS, + "tasks": python_tasks, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "environments": ENVIRONMENTS, + "run_name": TASK_ID, + } + ) + + assert expected == utils.normalise_json_content(op.json) + def test_init_with_pipeline_name_task_named_parameters(self): """ Test the initializer with the named parameters. From 0d72567cfa26b35c7a72bf08f3a6407c280ca452 Mon Sep 17 00:00:00 2001 From: hari-selvarajan_data Date: Mon, 23 Dec 2024 20:43:25 +0000 Subject: [PATCH 8/9] adding and fixing unit test cases --- .../databricks/operators/databricks.py | 4 +++- .../databricks/operators/test_databricks.py | 22 ++++++++++++++----- .../operators/test_databricks_workflow.py | 12 ++++++++++ 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/providers/src/airflow/providers/databricks/operators/databricks.py b/providers/src/airflow/providers/databricks/operators/databricks.py index 08a7e7541b74a..98560affc07b0 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/src/airflow/providers/databricks/operators/databricks.py @@ -1162,7 +1162,9 @@ def _convert_to_databricks_workflow_task( } if self.existing_cluster_id and self.job_cluster_key: - self.log.info("The task %s will be executed in serverless mode", result["task_key"]) + raise ValueError( + "Both existing_cluster_id and job_cluster_key are set. Only one can be set per task." 
+ ) if self.existing_cluster_id: result["existing_cluster_id"] = self.existing_cluster_id elif self.job_cluster_key: diff --git a/providers/tests/databricks/operators/test_databricks.py b/providers/tests/databricks/operators/test_databricks.py index 159e445710977..59d45cddd9515 100644 --- a/providers/tests/databricks/operators/test_databricks.py +++ b/providers/tests/databricks/operators/test_databricks.py @@ -2214,17 +2214,16 @@ def test_both_new_and_existing_cluster_set(self): exception_message = "Both new_cluster and existing_cluster_id are set. Only one should be set." assert str(exc_info.value) == exception_message - def test_both_new_and_existing_cluster_unset(self): + def test_both_new_and_existing_cluster_unset(self, caplog): operator = DatabricksNotebookOperator( task_id="test_task", notebook_path="test_path", source="test_source", databricks_conn_id="test_conn_id", ) - with pytest.raises(ValueError) as exc_info: - operator._get_run_json() - exception_message = "Must specify either existing_cluster_id or new_cluster." - assert str(exc_info.value) == exception_message + operator._get_run_json() + log_message = "The task adhoc_airflow__test_task will be executed in serverless mode" + assert log_message in caplog.text def test_job_runs_forever_by_default(self): operator = DatabricksNotebookOperator( @@ -2411,3 +2410,16 @@ def test_get_task_base_json(self): assert operator.task_config == task_config assert task_base_json == task_config + + def test_get_task_base_json_serverless(self): + task_config = SPARK_PYTHON_TASK + operator = DatabricksTaskOperator( + task_id="test_task", + databricks_conn_id="test_conn_id", + task_config=task_config, + databricks_environments=ENVIRONMENTS, + ) + task_base_json = operator._get_task_base_json() + + assert operator.task_config == task_config + assert task_base_json == task_config diff --git a/providers/tests/databricks/operators/test_databricks_workflow.py b/providers/tests/databricks/operators/test_databricks_workflow.py index fbc429ed1d9a8..c1a98c0a5cc03 100644 --- a/providers/tests/databricks/operators/test_databricks_workflow.py +++ b/providers/tests/databricks/operators/test_databricks_workflow.py @@ -77,9 +77,19 @@ def test_flatten_node(): def test_create_workflow_json(mock_databricks_hook, context, mock_task_group): """Test that _CreateDatabricksWorkflowOperator.create_workflow_json returns the expected JSON.""" + environments = [ + { + "environment_key": "default_environment", + "spec": { + "client": "1", + "dependencies": ["library1"], + }, + } + ] operator = _CreateDatabricksWorkflowOperator( task_id="test_task", databricks_conn_id="databricks_default", + databricks_environments=environments, ) operator.task_group = mock_task_group @@ -96,6 +106,7 @@ def test_create_workflow_json(mock_databricks_hook, context, mock_task_group): assert workflow_json["job_clusters"] == [] assert workflow_json["max_concurrent_runs"] == 1 assert workflow_json["timeout_seconds"] == 0 + assert workflow_json["environments"] == environments def test_create_or_reset_job_existing(mock_databricks_hook, context, mock_task_group): @@ -216,6 +227,7 @@ def test_task_group_exit_creates_operator(mock_databricks_workflow_operator): task_group=task_group, task_id="launch", databricks_conn_id="databricks_conn", + databricks_environments=[], existing_clusters=[], extra_job_params={}, job_clusters=[], From c7a9d5e4a2f64989c78b720ae713c214bbe4814c Mon Sep 17 00:00:00 2001 From: hari-selvarajan_data Date: Mon, 23 Dec 2024 21:23:05 +0000 Subject: [PATCH 9/9] added documentation 
with examples --- .../operators/jobs_create.rst | 1 + .../operators/submit_run.rst | 1 + .../operators/task.rst | 8 +++++++ .../databricks/operators/test_databricks.py | 9 +------- .../system/databricks/example_databricks.py | 23 +++++++++++++++++++ 5 files changed, 34 insertions(+), 8 deletions(-) diff --git a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst index 621423f83f32d..aeb8668e1fc89 100644 --- a/docs/apache-airflow-providers-databricks/operators/jobs_create.rst +++ b/docs/apache-airflow-providers-databricks/operators/jobs_create.rst @@ -56,6 +56,7 @@ Currently the named parameters that ``DatabricksCreateJobsOperator`` supports ar - ``max_concurrent_runs`` - ``git_source`` - ``access_control_list`` + - ``databricks_environments`` Examples diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst index 10548583cfa3f..dbdd5dd715137 100644 --- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst +++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst @@ -80,6 +80,7 @@ Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are - ``libraries`` - ``run_name`` - ``timeout_seconds`` + - ``databricks_environments`` .. code-block:: python diff --git a/docs/apache-airflow-providers-databricks/operators/task.rst b/docs/apache-airflow-providers-databricks/operators/task.rst index 331481d915c4b..503ce283c326f 100644 --- a/docs/apache-airflow-providers-databricks/operators/task.rst +++ b/docs/apache-airflow-providers-databricks/operators/task.rst @@ -44,3 +44,11 @@ Running a SQL query in Databricks using DatabricksTaskOperator :language: python :start-after: [START howto_operator_databricks_task_sql] :end-before: [END howto_operator_databricks_task_sql] + +Running a python file in Databricks in using DatabricksTaskOperator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. 
exampleinclude:: /../../providers/tests/system/databricks/example_databricks.py + :language: python + :start-after: [START howto_operator_databricks_task_python] + :end-before: [END howto_operator_databricks_task_python] + diff --git a/providers/tests/databricks/operators/test_databricks.py b/providers/tests/databricks/operators/test_databricks.py index 59d45cddd9515..aca0c95a1acee 100644 --- a/providers/tests/databricks/operators/test_databricks.py +++ b/providers/tests/databricks/operators/test_databricks.py @@ -696,14 +696,7 @@ def test_init_with_serverless_spark_python_task_named_parameters(self): "max_workers": 16, }, }, - "notebook_task": { - "notebook_path": "/Users/user.name@databricks.com/Match", - "source": "WORKSPACE", - "base_parameters": { - "name": "John Doe", - "age": "35", - }, - }, + "spark_python_task": {"python_file": "/Users/jsmith@example.com/example_file.py"}, "timeout_seconds": 86400, "max_retries": 3, "min_retry_interval_millis": 2000, diff --git a/providers/tests/system/databricks/example_databricks.py b/providers/tests/system/databricks/example_databricks.py index 999cebb674292..f178b4860ec71 100644 --- a/providers/tests/system/databricks/example_databricks.py +++ b/providers/tests/system/databricks/example_databricks.py @@ -238,6 +238,29 @@ ) # [END howto_operator_databricks_task_sql] + # [START howto_operator_databricks_task_python] + environments = [ + { + "environment_key": "default_environment", + "spec": { + "client": "1", + "dependencies": ["library1"], + }, + } + ] + task_operator_python_query = DatabricksTaskOperator( + task_id="python_task", + databricks_conn_id="databricks_conn", + task_config={ + "spark_python_task": { + "python_file": "/Users/jsmith@example.com/example_file.py", + }, + "environment_key": "default_environment", + }, + databricks_environments=environments, + ) + # [END howto_operator_databricks_task_python] + from tests_common.test_utils.watcher import watcher # This test needs watcher in order to properly mark success/failure
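
For reference, a minimal sketch of how the databricks_environments parameter introduced in this series might be wired up from a DAG. It mirrors the example added in PATCH 9/9; the connection id, python_file path, DAG id, and dependency name are placeholders, and it assumes a provider build that already contains these patches.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksTaskOperator

# Serverless environment definitions referenced by "environment_key" in the task config.
environments = [
    {
        "environment_key": "default_environment",
        "spec": {
            "client": "1",
            "dependencies": ["library1"],
        },
    }
]

with DAG(
    dag_id="example_databricks_serverless_task",  # placeholder DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    # No new_cluster, existing_cluster_id, or job_cluster_key is supplied, so per the
    # changes above the task is submitted in serverless mode and picks up the
    # environment registered under "default_environment".
    python_task = DatabricksTaskOperator(
        task_id="python_task",
        databricks_conn_id="databricks_conn",  # placeholder connection id
        task_config={
            "spark_python_task": {
                "python_file": "/Users/jsmith@example.com/example_file.py",
            },
            "environment_key": "default_environment",
        },
        databricks_environments=environments,
    )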