Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions skrules/datasets/credit_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,15 @@

import pandas as pd
import numpy as np
from sklearn.datasets.base import get_data_home, Bunch
from sklearn.datasets.base import _fetch_remote, RemoteFileMetadata

try:
from sklearn.datasets.base import get_data_home, Bunch, _fetch_remote, RemoteFileMetadata
except (ModuleNotFoundError, ImportError):
# For scikit-learn >= 0.24 compatibility
from sklearn.datasets import get_data_home
from sklearn.utils import Bunch
from sklearn.datasets._base import _fetch_remote, RemoteFileMetadata

from os.path import exists, join


Expand All @@ -33,6 +40,8 @@ def load_credit_data():
'011238620f5369220bd60cfc82700933'))

if not exists(join(sk_data_dir, archive.filename)):
import socket
socket.setdefaulttimeout(180)
_fetch_remote(archive, dirname=sk_data_dir)

data = pd.read_excel(join(sk_data_dir, archive.filename),
Expand Down
20 changes: 19 additions & 1 deletion skrules/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
from sklearn.utils.estimator_checks import check_estimator
from skrules import SkopeRules
from skrules.datasets import load_credit_data
import sklearn


def test_classifier():
check_estimator(SkopeRules)
try:
check_estimator(SkopeRules)
except TypeError:
# For sklearn >= 0.24.0 compatibility
from sklearn.utils._testing import SkipTest
from sklearn.utils.estimator_checks import check_sample_weights_invariance

checks = check_estimator(SkopeRules(), generate_only=True)
for estimator, check in checks:
# Here we ignore this particular estimator check because
# sample weights are treated differently in skope-rules
if check.func != check_sample_weights_invariance:
try:
check(estimator)
except SkipTest as exception:
# SkipTest is thrown when pandas can't be imported, or by checks
# that are in the xfail_checks tag
warnings.warn(str(exception), SkipTestWarning)


def test_load_credit_data():
Expand Down
57 changes: 20 additions & 37 deletions skrules/tests/test_rule.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,48 @@
from sklearn.utils.testing import assert_equal, assert_not_equal

from skrules import Rule, replace_feature_name


def test_rule():
assert_equal(Rule('a <= 10 and a <= 12'),
Rule('a <= 10'))
assert_equal(Rule('a <= 10 and a <= 12 and a > 3'),
Rule('a > 3 and a <= 10'))
assert Rule("a <= 10 and a <= 12") == Rule("a <= 10")
assert Rule("a <= 10 and a <= 12 and a > 3") == Rule("a > 3 and a <= 10")

assert_equal(Rule('a <= 10 and a <= 10 and a > 3'),
Rule('a > 3 and a <= 10'))
assert Rule("a <= 10 and a <= 10 and a > 3") == Rule("a > 3 and a <= 10")

assert_equal(Rule('a <= 10 and a <= 12 and b > 3 and b > 6'),
Rule('a <= 10 and b > 6'))
assert Rule("a <= 10 and a <= 12 and b > 3 and b > 6") == Rule("a <= 10 and b > 6")

assert_equal(len({Rule('a <= 2 and a <= 3'),
Rule('a <= 2')
}), 1)
assert len({Rule("a <= 2 and a <= 3"), Rule("a <= 2")}) == 1

assert_equal(len({Rule('a > 2 and a > 3 and b <= 2 and b <= 3'),
Rule('a > 3 and b <= 2')
}), 1)
assert (
len({Rule("a > 2 and a > 3 and b <= 2 and b <= 3"), Rule("a > 3 and b <= 2")})
== 1
)

assert_equal(len({Rule('a <= 3 and b <= 2'),
Rule('b <= 2 and a <= 3')
}), 1)
assert len({Rule("a <= 3 and b <= 2"), Rule("b <= 2 and a <= 3")}) == 1


def test_hash_rule():
assert_equal(len({
Rule('a <= 2 and a <= 3'),
Rule('a <= 2')
}), 1)
assert_not_equal(len({
Rule('a <= 4 and a <= 3'),
Rule('a <= 2')
}), 1)
assert len({Rule("a <= 2 and a <= 3"), Rule("a <= 2")}) == 1
assert len({Rule("a <= 4 and a <= 3"), Rule("a <= 2")}) != 1


def test_str_rule():
rule = 'a <= 10.0 and b > 3.0'
assert_equal(rule, str(Rule(rule)))
rule = "a <= 10.0 and b > 3.0"
assert rule == str(Rule(rule))


def test_equals_rule():
rule = "a == a"
assert_equal(rule, str(Rule(rule)))
assert rule == str(Rule(rule))

rule2 = "a == a and a == a"
assert_equal(rule, str(Rule(rule2)))
assert rule == str(Rule(rule2))

rule3 = "a < 3.0 and a == a"
assert_equal(rule3, str(Rule(rule3)))
assert rule3 == str(Rule(rule3))


def test_replace_feature_name():
rule = "__C__0 <= 3 and __C__1 > 4"
real_rule = "$b <= 3 and c(4) > 4"
replace_dict = {
"__C__0": "$b",
"__C__1": "c(4)"
}
assert_equal(replace_feature_name(rule, replace_dict=replace_dict), real_rule)
replace_dict = {"__C__0": "$b", "__C__1": "c(4)"}
assert replace_feature_name(rule, replace_dict=replace_dict) == real_rule

Loading