Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions sync/databricks/integrations/airflow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
from sync.databricks.integrations._run_submit_runner import apply_sync_gradient_cluster_recommendation

from sync.databricks.integrations._run_submit_runner import (
apply_sync_gradient_cluster_recommendation,
)

logger = logging.getLogger(__name__)

Expand All @@ -9,21 +12,28 @@ def airflow_gradient_pre_execute_hook(context: dict):
logger.info("Running airflow gradient pre-execute hook!")
logger.debug(f"Airflow operator context - context:{context}")

task_id = context["task"].task_id
gradient_app_id = context["params"]["gradient_app_id"]
auto_apply = context["params"]["gradient_auto_apply"]
cluster_log_url = context["params"]["cluster_log_url"]
workspace_id = context["params"]["databricks_workspace_id"]
run_submit_task = context["task"].json.copy() # copy the run submit json from the task context
run_submit_task = context[
"task"
].json.copy() # copy the run submit json from the task context

updated_task_configuration = apply_sync_gradient_cluster_recommendation(
run_submit_task=run_submit_task,
gradient_app_id=gradient_app_id,
gradient_app_id=build_app_id(task_id, gradient_app_id),
auto_apply=auto_apply,
cluster_log_url=cluster_log_url,
workspace_id=workspace_id
workspace_id=workspace_id,
)

context["task"].json = updated_task_configuration
except Exception as e:
logger.exception(e)
logger.error("Unable to apply gradient configuration to Databricks run submit tasks")


def build_app_id(task_id: str, app_id: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def build_app_id(task_id: str, app_id: str):
def build_app_id(task_id: str, app_id: str) -> str:

return f"{task_id}-{app_id}"