Skip to content
This repository is currently being migrated. It's locked while the migration is in progress.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions integration/Gradient-Train-Job.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# MAGIC * The Gradient Webhook has been configured
# MAGIC * The Databricks Job has been Gradient enabled
# MAGIC
# MAGIC When bypassing the Gradient Webhook with AWS, the cluster attached to this notebook must have an instance_arn with describe_instances and describe_volumes permissions. When bypassing the Gradient Webhook with Azure, the following environment variables must be set with the correct values: "AZURE_TENANT_ID", "AZURE_SUBSCRIPTION_ID", "AZURE_CLIENT_SECRET", "AZURE_CLIENT_ID"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just added this in case it's unclear what customers would need to do to bypass the webhook. I'm fine with taking it out if bypassing the webhook is supposed to be a complete edge case

# MAGIC
# MAGIC This job will configure all runs to execute using ON DEMAND nodes only. The original settings will be restored after training is complete.
Copy link
Contributor Author

@CaymanWilliams CaymanWilliams Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this true? Are the settings ever restored? (Supposed to be looking at the line under the highlighted one)

# MAGIC

Expand All @@ -41,7 +43,6 @@
os.environ["SYNC_API_KEY_SECRET"] = dbutils.widgets.get("Sync API Key Secret")
os.environ["DATABRICKS_HOST"] = dbutils.widgets.get("Databricks Host").rstrip('\/')


print(f"DATABRICKS_JOB_ID: {DATABRICKS_JOB_ID}")
print(f"TRAINING_RUNS: {TRAINING_RUNS}")
print(f"BYPASS_WEBHOOK: {BYPASS_WEBHOOK}")
Expand Down Expand Up @@ -73,18 +74,19 @@
else:
raise ValueError(f"Unsupported platform: {platform}")

if BYPASS_WEBHOOK:
access_report = sync_databricks.get_access_report()

access_report = sync_databricks.get_access_report()
for line in access_report:
logger.info(line)

for line in access_report:
logger.info(line)

assert not any(line.status is AccessStatusCode.RED for line in access_report), "Required access is missing"
assert not any(line.status is AccessStatusCode.RED for line in access_report), "Required access is missing"

# COMMAND ----------

from typing import Optional

def get_cluster_for_job(job: dict | None) -> dict:
def get_cluster_for_job(job: Optional[dict]) -> dict:

if job is None:
job = sync_databricks_client.get_job(DATABRICKS_JOB_ID)
Expand All @@ -110,11 +112,10 @@ def get_cluster_for_job(job: dict | None) -> dict:
else:
raise ValueError("Could not identify a cluster for this job")

def get_tag_for_job(job: dict, tag_key: str) -> Optional[str]:
    """Look up a custom tag on the job's cluster.

    Args:
        job: Databricks job definition dict whose cluster carries the tags.
        tag_key: Name of the custom tag to read.

    Returns:
        The tag's value, or None when the cluster has no such custom tag.
    """
    tags = get_cluster_for_job(job)["custom_tags"]
    return tags.get(tag_key)


def validate_job():
logger.info("Validating Databricks Job")
job = sync_databricks_client.get_job(DATABRICKS_JOB_ID)
Expand Down Expand Up @@ -204,7 +205,7 @@ def validate_job():
# COMMAND ----------

class RecommendationError(Exception):
    """Raised when something goes wrong with the generation of a GradientML Recommendation."""

    def __init__(self, error):
        # Preserve the underlying error's text in the exception message.
        message = "recommendation Error: " + str(error)
        super().__init__(message)
Expand All @@ -227,7 +228,7 @@ def run_job(run_job_id: str):
break
return run

def wait_for_recommendation(starting_recommendation_id: str | None) -> None:
def wait_for_recommendation(starting_recommendation_id: Optional[str]) -> None:
logger.info(f"waiting for log submission and rec generation and application")
logger.info(f"starting recommendation id: {starting_recommendation_id}")

Expand Down