Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .github/workflows/install.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Install Tests
on:
pull_request:
types: [opened, synchronize]
push:
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
install:
name: ${{ matrix.python_version }} install
strategy:
fail-fast: true
matrix:
python_version: ["3.8", "3.13"]
runs-on: ubuntu-latest
steps:
- name: Set up python ${{ matrix.python_version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
- uses: actions/checkout@v4
- name: Build package
run: |
make package
- name: Install package
run: |
python -m pip install "unpacked_sdist/."
- name: Test by importing packages
run: |
python -c "import sdmetrics"
- name: Check package conflicts
run: |
python -m pip check
20 changes: 20 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,23 @@ release-minor: check-release bumpversion-minor release

.PHONY: release-major
release-major: check-release bumpversion-major release

# Packaging Targets
.PHONY: upgradepip
upgradepip:
python -m pip install --upgrade pip

.PHONY: upgradebuild
upgradebuild:
python -m pip install --upgrade build

.PHONY: upgradesetuptools
upgradesetuptools:
python -m pip install --upgrade setuptools

.PHONY: package
package: upgradepip upgradebuild upgradesetuptools
python -m build ; \
$(eval VERSION=$(shell python -c 'import setuptools; setuptools.setup()' --version))
tar -zxvf "dist/sdmetrics-${VERSION}.tar.gz"
mv "sdmetrics-${VERSION}" unpacked_sdist
9 changes: 8 additions & 1 deletion sdmetrics/single_table/bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@

import numpy as np
import pandas as pd
import torch

from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric

try:
import torch
except ModuleNotFoundError:
torch = None

LOGGER = logging.getLogger(__name__)


Expand All @@ -19,6 +23,9 @@ class BNLikelihoodBase(SingleTableMetric):
def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None):
try:
from pomegranate.bayesian_network import BayesianNetwork

if torch is None:
raise ImportError
except ImportError:
raise ImportError(
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'
Expand Down
15 changes: 14 additions & 1 deletion tests/unit/single_table/test_bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,20 @@ def metadata():
class TestBNLikelihood:
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None})
def test_compute_error(self):
"""Test that an `ImportError` is raised."""
"""Test that an `ImportError` is raised when pomegranate isn't installed."""
# Setup
metric = BNLikelihood()

# Run and Assert
expected_message = re.escape(
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'
)
with pytest.raises(ImportError, match=expected_message):
metric.compute(Mock(), Mock())

@patch.dict('sys.modules', {'torch': None})
def test_compute_error_torch_is_none(self):
"""Test that an `ImportError` is raised when torch isn't installed."""
# Setup
metric = BNLikelihood()

Expand Down
Loading