Skip to content

Commit 47f649e

Browse files
authored
Optuna 3.1 support (#33)
* optuna 3.1 debug * cleanup * cleanup * cleanup
1 parent 5f9ba95 commit 47f649e

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ install_requires =
4646
torch>=1.10.0
4747
numpy>=1.20
4848
catboost>=1.0.5
49-
optuna>=2.10
49+
optuna>=3.1
5050
loguru==.0.6.0
5151
xgboost>=1.6.1
5252
miracle-imputation>=0.1.3

src/hyperimpute/utils/optimizer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# third party
66
import optuna
7+
from optuna.storages import JournalRedisStorage, JournalStorage
78
import redis
89

910
# hyperimpute absolute
@@ -21,17 +22,16 @@ def __init__(
2122
):
2223
self.url = f"redis://{host}:{port}/"
2324

24-
self._optuna_storage = optuna.storages.RedisStorage(url=self.url)
25+
self._optuna_storage = JournalStorage(JournalRedisStorage(url=self.url))
2526
self._client = redis.Redis.from_url(self.url)
2627

27-
def optuna(self) -> optuna.storages.RedisStorage:
28+
def optuna(self) -> JournalStorage:
2829
return self._optuna_storage
2930

3031
def client(self) -> redis.Redis:
3132
return self._client
3233

3334

34-
backend = RedisBackend()
3535
threshold = 40
3636

3737

@@ -104,7 +104,11 @@ def create_study(
104104
patience: int = threshold,
105105
) -> Tuple[optuna.Study, ParamRepeatPruner]:
106106

107-
storage_obj = backend.optuna()
107+
try:
108+
backend = RedisBackend()
109+
storage_obj = backend.optuna()
110+
except BaseException:
111+
storage_obj = None
108112

109113
try:
110114
study = optuna.create_study(

src/hyperimpute/version.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1-
__version__ = "0.1.12"
2-
MAJOR_VERSION = "0.1"
1+
__version__ = "0.1.13"
2+
3+
MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
4+
MINOR_VERSION = __version__.split(".")[-1]

0 commit comments

Comments
 (0)