Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/dataworkbench/datacatalogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pyspark.sql import DataFrame

from dataworkbench.utils import get_secret
from dataworkbench.utils import get_secret, SparkDataFrame
from dataworkbench.storage import DeltaStorage
from dataworkbench.gateway import Gateway

Expand Down Expand Up @@ -46,9 +46,6 @@ def __build_storage_table_root_url(self, folder_id: uuid.UUID) -> str:
if not isinstance(folder_id, uuid.UUID):
raise TypeError("folder_id must be uuid")

if not folder_id:
raise ValueError("folder_id cannot be empty")

return f"{self.storage_base_url}/{folder_id}"

def __build_storage_table_processed_url(self, folder_id: uuid.UUID) -> str:
Expand Down Expand Up @@ -110,7 +107,7 @@ def save(
... )
"""
# Validate input parameters
if not hasattr(df, "write"):
if not isinstance(df, SparkDataFrame):
raise TypeError("df must be a DataFrame")

if not isinstance(dataset_name, str) or not dataset_name:
Expand Down
11 changes: 6 additions & 5 deletions src/dataworkbench/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ def import_dataset(
if e.response is not None
else None
)
logger.error(f"Error creating data catalog entry: {e}")
return {
"error": "Failed to create data catalog entry.",
"correlation-id": trace_id,
}
error_msg = (
f"Failed to create data catalog entry. correlation-id: {trace_id}"
)

logger.error(error_msg)
raise type(e)(error_msg) from e
10 changes: 9 additions & 1 deletion src/dataworkbench/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from pyspark.sql import SparkSession
from pyspark.sql import SparkSession, DataFrame

from dataworkbench.log import setup_logger

Expand Down Expand Up @@ -70,6 +70,14 @@ def get_secret(key: str, scope: str = "dwsecrets") -> str:
return secret


if is_databricks():
from pyspark.sql.connect.dataframe import DataFrame as DatabricksDataFrame

SparkDataFrame = DataFrame | DatabricksDataFrame
else:
SparkDataFrame = DataFrame


# Example usage
if __name__ == "__main__":
CLIENT_ID = get_secret("ClientId")
Expand Down
7 changes: 5 additions & 2 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import requests
from unittest.mock import patch, MagicMock
from dataworkbench.gateway import Gateway
from requests.exceptions import RequestException
import json

@pytest.fixture
Expand Down Expand Up @@ -32,6 +33,7 @@ def test_import_dataset_success(mock_gateway, mock_post):
assert result == {"status": "success"}
mock_post.assert_called_once()


def test_import_dataset_failure(mock_gateway, mock_post):
"""Test dataset import failure."""

Expand All @@ -47,7 +49,8 @@ def test_import_dataset_failure(mock_gateway, mock_post):
mock_response.raise_for_status.side_effect = http_error
mock_post.return_value = mock_response

result = mock_gateway.import_dataset("dataset_name", "dataset_description", "schema_id", {"tag": "value"}, "folder_id")
with pytest.raises(RequestException) as e:
mock_gateway.import_dataset("dataset_name", "dataset_description", "schema_id", {"tag": "value"}, "folder_id")

assert result == {"error": "Failed to create data catalog entry.", "correlation-id": response_body["traceId"]}
assert e.value.args[0] == f"Failed to create data catalog entry. correlation-id: {response_body['traceId']}"
mock_post.assert_called_once()