65 changes: 51 additions & 14 deletions src/dataworkbench/datacatalogue.py
@@ -39,9 +39,21 @@ def __init__(self) -> None:
self.gateway: Gateway = Gateway()
self.storage_base_url: str = get_secret("StorageBaseUrl")

def __build_storage_url(self, folder_id: uuid.UUID) -> str:
def __build_storage_table_root_url(self, folder_id: uuid.UUID) -> str:
"""
Build the ABFSS URL for the target storage location.
Build the ABFSS URL for the root location of the table.
"""
if not isinstance(folder_id, uuid.UUID):
raise TypeError("folder_id must be uuid")

if not folder_id:
raise ValueError("folder_id cannot be empty")

return f"{self.storage_base_url}/{folder_id}"

def __build_storage_table_processed_url(self, folder_id: uuid.UUID) -> str:
"""
Build the ABFSS URL for the processed table storage location.

Args:
folder_id: Unique identifier for the storage folder
@@ -51,15 +63,10 @@ def __build_storage_url(self, folder_id: uuid.UUID) -> str:

Example:
>>> catalogue = DataCatalogue()
>>> catalogue._build_storage_url("abc123")
>>> catalogue._DataCatalogue__build_storage_table_processed_url(uuid.uuid4())
"""
if not isinstance(folder_id, uuid.UUID):
raise TypeError("folder_id must be uuid")

if not folder_id:
raise ValueError("folder_id cannot be empty")

return f"{self.storage_base_url}/{folder_id}/Processed"
table_root_url = self.__build_storage_table_root_url(folder_id)
return f"{table_root_url}/Processed"

def save(
self,
@@ -118,15 +125,45 @@ def save(
# Generate folder_id
folder_id = uuid.uuid4()

target_path = self.__build_storage_url(folder_id)
target_path = self.__build_storage_table_processed_url(folder_id)

try:
# Write data using the specified or defaulted mode
self.storage.write(df, target_path, mode=WriteMode.OVERWRITE.value)

return self.gateway.import_dataset(
dataset_name, dataset_description, schema_id, tags or {}, folder_id
)
try:
# Register the dataset with the Gateway API
return self.gateway.import_dataset(
dataset_name, dataset_description, schema_id, tags or {}, folder_id
)
except Exception as e:
self._rollback_write(folder_id)

# Raise the original API error with additional context
error_msg = (
f"Gateway API call failed and storage was rolled back: {str(e)}"
)
raise type(e)(error_msg) from e

except Exception as e:
return {"error": str(e), "error_type": type(e).__name__}

def _rollback_write(self, folder_id: uuid.UUID) -> None:
"""
Delete a table's data from storage to roll back changes when an operation fails.

Args:
    folder_id: Unique identifier for the storage folder whose contents should be deleted
"""
target_path = self.__build_storage_table_root_url(folder_id)
logger.info("Rolling back data write operation to storage")
try:
self.storage.delete(target_path, recursive=True)
except Exception as rollback_error:
    logger.error(
        f"Failed to roll back storage operation at {target_path}: {str(rollback_error)}"
    )
else:
    logger.info(
        f"Successfully rolled back data write operation by deleting: {target_path}"
    )
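As a sanity check on the new rollback flow, here is a rough usage sketch. DataCatalogue, save, and the error-dict return shape are taken from the hunks above; the save signature is partly collapsed in this diff, so the keyword arguments and the Spark setup below are assumptions.

import uuid
from pyspark.sql import SparkSession
from dataworkbench.datacatalogue import DataCatalogue

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Alice", 30)], ["name", "age"])

catalogue = DataCatalogue()
result = catalogue.save(
    df,
    dataset_name="people",
    dataset_description="Example dataset",
    schema_id=uuid.uuid4(),      # assumed type; only the name appears in the hunks
    tags={"owner": "data-team"},
)

# If the Gateway call fails, the new folder is removed via _rollback_write and
# save() returns an error dict from the outer except instead of raising.
if isinstance(result, dict) and "error" in result:
    print(f"Save failed ({result['error_type']}): {result['error']}")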
2 changes: 1 addition & 1 deletion src/dataworkbench/gateway.py
@@ -138,4 +138,4 @@ def import_dataset(
return self.__send_request(url, payload)
except requests.exceptions.RequestException as e:
logger.error(f"Error creating data catalog entry: {e}")
return {"error": f"Failed to create data catalog entry: {str(e)}"}
raise
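Since import_dataset now re-raises RequestException instead of swallowing it into an error dict, callers other than DataCatalogue.save have to handle the failure themselves. A minimal sketch, assuming Gateway() takes no constructor arguments as in the tests below:

import uuid
import requests
from dataworkbench.gateway import Gateway

gateway = Gateway()
try:
    response = gateway.import_dataset(
        "people",                # dataset_name
        "Example dataset",       # dataset_description
        uuid.uuid4(),            # schema_id (assumed type)
        {"owner": "data-team"},  # tags
        uuid.uuid4(),            # folder_id
    )
except requests.exceptions.RequestException as e:
    # Previously this surfaced as {"error": ...}; now the exception propagates.
    print(f"Failed to create data catalog entry: {e}")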
84 changes: 73 additions & 11 deletions src/dataworkbench/storage.py
@@ -1,10 +1,12 @@
from typing import Any, Literal
from typing import Literal
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.utils import AnalysisException
from abc import ABC, abstractmethod
from dataworkbench.utils import get_dbutils, PrimitiveType, is_databricks

from dataworkbench.log import setup_logger


# Configure logging
logger = setup_logger(__name__)

@@ -24,7 +26,7 @@ def write(
df: DataFrame,
target_path: str,
mode: Literal["overwrite", "append", "error", "ignore"] = "overwrite",
**options: dict[str, Any],
**options: PrimitiveType | None,
) -> None:
"""
Write a DataFrame to storage.
@@ -60,7 +62,7 @@ def check_path_exists(self, path: str) -> bool:
pass

@abstractmethod
def read(self, source_path: str, **options: dict[str, Any]) -> DataFrame:
def read(self, source_path: str, **options: PrimitiveType | None) -> DataFrame:
"""
Read data from storage into a DataFrame.

@@ -98,6 +100,7 @@ def __init__(self, spark_session: SparkSession | None = None):
raise TypeError("spark_session must be a SparkSession or None")

self._spark = spark_session
self._dbutils = get_dbutils(self._spark)

@property
def spark(self) -> SparkSession:
@@ -127,7 +130,8 @@ def write(
df: DataFrame,
target_path: str,
mode: Literal["overwrite", "append", "error", "ignore"] = "overwrite",
**options: dict[str, Any],
partition_by: str | list[str] | None = None,
**options: PrimitiveType | None,
) -> None:
"""
Write a DataFrame to storage in Delta format.
@@ -172,8 +176,11 @@ def write(
writer = df.write.format("delta").mode(mode)

# Apply options if provided
for key, value in options.items():
writer = writer.option(key, value)
if options:
writer = writer.options(**options)

if partition_by:
writer = writer.partitionBy(partition_by)

# Save the data
writer.save(target_path)
@@ -189,7 +196,7 @@ def append(
df: DataFrame,
target_path: str,
partition_by: str | list[str] | None = None,
**options: dict[str, Any],
**options: PrimitiveType | None,
) -> None:
"""
Append a DataFrame to existing data in Delta format.
@@ -213,7 +220,13 @@ def append(
>>> new_records = spark.createDataFrame([("Charlie", 35)], ["name", "age"])
>>> storage.append(new_records, "abfss://container@account.dfs.core.windows.net/path/to/data")
"""
self.write(df, target_path, mode="append", partition_by=partition_by, **options)
self.write(
df=df,
target_path=target_path,
mode="append",
partition_by=partition_by,
**options,
)

def check_path_exists(self, path: str) -> bool:
"""
@@ -247,7 +260,7 @@ def check_path_exists(self, path: str) -> bool:
logger.warning(f"Error checking path existence: {e}")
return False

def read(self, source_path: str, **options: dict[str, Any]) -> DataFrame:
def read(self, source_path: str, **options: PrimitiveType | None) -> DataFrame:
"""
Read a Delta table from storage into a DataFrame.

@@ -274,8 +287,8 @@ def read(self, source_path: str, **options: dict[str, Any]) -> DataFrame:
reader = self.spark.read.format("delta")

# Apply options if provided
for key, value in options.items():
reader = reader.option(key, value)
if options:
reader = reader.options(**options)

# Load the data
return reader.load(source_path)
@@ -284,3 +297,52 @@ def read(self, source_path: str, **options: dict[str, Any]) -> DataFrame:
error_msg = f"Failed to read data from {source_path}: {e}"
logger.error(error_msg)
raise RuntimeError(error_msg) from e

def file_exists(self, path: str) -> bool:
    """Check whether a path exists in storage (Databricks only)."""
    if is_databricks():
        try:
            self._dbutils.fs.ls(path)
            return True
        except Exception as e:
            # dbutils surfaces a missing path as a generic exception wrapping java.io.FileNotFoundException
            if "java.io.FileNotFoundException" in str(e):
                return False
            raise
    logger.info("file_exists is not implemented outside Databricks")
    return False

def delete(self, path: str, recursive: bool = True) -> None:
"""
Delete a directory from Azure Storage using Spark.

Args:
path: The path to the file / directory in Azure Storage to delete
recursive: If True, recursively delete all subdirectories and files

Raises:
TypeError: If path is not a string
ValueError: If path is empty
Exception: If any error occurs during deletion
"""
if not is_databricks():
raise RuntimeError("Delete does not work outside databricks")

if not isinstance(path, str):
raise TypeError("path must be a non-empty string")

if not path:
raise ValueError("path cannot be empty")

try:
logger.info(f"Deleting path: {path}, recursive={recursive}")

if not self.file_exists(path):
logger.warning(f"Path does not exist, nothing to delete: {path}")
return

# Delete the path, honouring the requested recursion setting
self._dbutils.fs.rm(path, recurse=recursive)

except Exception as e:
logger.error(f"Failed to delete {path}: {str(e)}")
raise Exception(f"Failed to delete: {str(e)}") from e
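A short sketch of the reworked storage surface. The concrete class name is not visible in these hunks, so DatabricksStorage is assumed; the path is illustrative, and file_exists/delete only do real work on Databricks.

from pyspark.sql import SparkSession
from dataworkbench.storage import DatabricksStorage  # assumed class name

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Alice", 30, "2024-01-01")], ["name", "age", "ingest_date"])

storage = DatabricksStorage(spark)
path = "abfss://container@account.dfs.core.windows.net/datasets/<folder-id>/Processed"

# write() now takes partition_by plus flat writer options (PrimitiveType values)
storage.write(df, path, mode="overwrite", partition_by="ingest_date", mergeSchema="true")

# Clean-up mirrors the rollback path; delete() also checks file_exists() internally
if storage.file_exists(path):
    storage.delete(path, recursive=True)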
37 changes: 28 additions & 9 deletions src/dataworkbench/utils.py
@@ -1,6 +1,14 @@
import os
from pyspark.sql import SparkSession

from dataworkbench.log import setup_logger

# Configure logging
logger = setup_logger(__name__)


PrimitiveType = str | int | float | bool


def get_spark() -> SparkSession:
"""
@@ -23,24 +31,35 @@ def is_databricks():
return os.getenv("DATABRICKS_RUNTIME_VERSION") is not None


def get_secret(key: str, scope: str = "dwsecrets") -> str:
def get_dbutils(spark: SparkSession | None = None):
"""
Retrieve a secret from dbutils if running on Databricks, otherwise fallback to env variables.
Get dbutils module
"""

secret = None # Default value

if is_databricks():
try:
from pyspark.dbutils import DBUtils # type: ignore

spark = get_spark()
dbutils = DBUtils(spark)
secret = dbutils.secrets.get(scope, key)
except ImportError:
raise RuntimeError(
"dbutils module not found. Ensure this is running on Databricks."
)
try:
return DBUtils(spark)
except Exception as e:
logger.error(f"Failed to create dbutils: {e}")
raise RuntimeError("No dbutils available") from e
else:
return None


def get_secret(key: str, scope: str = "dwsecrets") -> str:
"""
Retrieve a secret from dbutils if running on Databricks; otherwise fall back to environment variables.
"""

dbutils = get_dbutils()

if dbutils:
secret = dbutils.secrets.get(scope, key)
else:
secret = os.getenv(key)

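The secret lookup is now split in two: get_dbutils() owns the Databricks detection and get_secret() only chooses between dbutils and the environment. A rough sketch of the local, non-Databricks path, using the StorageBaseUrl secret that DataCatalogue reads:

import os
from dataworkbench.utils import get_secret, is_databricks

if not is_databricks():
    # Locally there is no dbutils: get_dbutils() returns None and get_secret()
    # falls back to the environment variable of the same name.
    os.environ["StorageBaseUrl"] = "abfss://container@account.dfs.core.windows.net/datasets"

print(get_secret("StorageBaseUrl"))  # dbutils scope "dwsecrets" on Databricks, env var locally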
5 changes: 3 additions & 2 deletions tests/test_gateway.py
@@ -2,6 +2,7 @@
import requests
from unittest.mock import patch, MagicMock
from dataworkbench.gateway import Gateway
from requests.exceptions import RequestException

@pytest.fixture
def mock_gateway():
@@ -35,7 +36,7 @@ def test_import_dataset_failure(mock_gateway, mock_post):
"""Test dataset import failure."""
mock_post.side_effect = requests.exceptions.RequestException("Request failed")

result = mock_gateway.import_dataset("dataset_name", "dataset_description", "schema_id", {"tag": "value"}, "folder_id")
with pytest.raises(RequestException):
mock_gateway.import_dataset("dataset_name", "dataset_description", "schema_id", {"tag": "value"}, "folder_id")

assert result == {"error": "Failed to create data catalog entry: Request failed"}
mock_post.assert_called_once()
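The updated failure test only covers the Gateway layer; the rollback path in DataCatalogue.save could get a companion test along these lines. _rollback_write, the error-dict return shape, and the storage/gateway attributes come from the diff above, while the argument order and the __new__ trick to skip the secret lookup are assumptions:

import uuid
from unittest.mock import MagicMock

import requests

from dataworkbench.datacatalogue import DataCatalogue


def test_save_rolls_back_storage_when_gateway_fails():
    # Bypass __init__ so no real secret lookup or Gateway construction happens
    catalogue = DataCatalogue.__new__(DataCatalogue)
    catalogue.storage = MagicMock()
    catalogue.gateway = MagicMock()
    catalogue.storage_base_url = "abfss://container@account.dfs.core.windows.net/datasets"
    catalogue.gateway.import_dataset.side_effect = requests.exceptions.RequestException("Request failed")

    result = catalogue.save(MagicMock(), "dataset_name", "dataset_description", uuid.uuid4())

    catalogue.storage.delete.assert_called_once()      # storage write was rolled back
    assert result["error_type"] == "RequestException"  # outer except converts to an error dict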