Skip to content

Commit 59adf4e

Browse files
committed
Some additinal type hinting and misc cleanups
1 parent 61be769 commit 59adf4e

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

slots/slots.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
mab.online_trial(bandit=1, payout=0)
1919
"""
2020

21-
from typing import Optional, List, Dict, Any, Union
21+
from typing import Optional, List, Dict, Any, Union, Callable
2222

2323
import numpy as np
2424

@@ -33,7 +33,7 @@ def __init__(
3333
num_bandits: Optional[int] = 3,
3434
probs: Optional[np.ndarray] = None,
3535
hist_payouts: Optional[List[np.ndarray]] = None,
36-
live: bool = False,
36+
live: Optional[bool] = False,
3737
stop_criterion: Optional[Dict] = {"criterion": "regret", "value": 0.1},
3838
) -> None:
3939
"""
@@ -59,7 +59,7 @@ def __init__(
5959
if live:
6060
# Live trial scenario, where nothing is known except the
6161
# number of bandits
62-
self.bandits = Bandits(
62+
self.bandits: Bandits = Bandits(
6363
live=True, payouts=np.zeros(num_bandits)
6464
)
6565
else:
@@ -109,7 +109,9 @@ def __init__(
109109
self.pulls: np.ndarray = np.zeros(num_bandits)
110110

111111
# Set the stopping criteria
112-
self.criteria: Dict = {"regret": self.regret_met}
112+
self.criteria: Dict[str, Callable[[Optional[float]], bool]] = {
113+
"regret": self.regret_met
114+
}
113115
if not stop_criterion:
114116
self.criterion: str = "regret"
115117
self.stop_value: float = 0.1
@@ -243,14 +245,14 @@ def bayesian(self, params: Any = None) -> int:
243245
int
244246
Index of chosen bandit
245247
"""
246-
p_success_arms = [
248+
p_success_arms: List[float] = [
247249
np.random.beta(self.wins[i] + 1, self.pulls[i] - self.wins[i] + 1)
248250
for i in range(len(self.wins))
249251
]
250252

251253
return np.array(p_success_arms).argmax()
252254

253-
def eps_greedy(self, params: Optional[Dict] = None) -> int:
255+
def eps_greedy(self, params: Optional[Dict[str, float]] = None) -> int:
254256
"""
255257
Run the epsilon-greedy strategy and update self.max_mean()
256258
@@ -265,12 +267,19 @@ def eps_greedy(self, params: Optional[Dict] = None) -> int:
265267
Index of chosen bandit
266268
"""
267269

270+
default_eps: float = 0.1
271+
268272
if params and type(params) == dict:
269-
eps = params.get("epsilon")
273+
eps: float = params.get("epsilon", default_eps)
274+
try:
275+
float(eps)
276+
except ValueError:
277+
print("slots: eps_greedy: Setting eps to default")
278+
eps = default_eps
270279
else:
271-
eps = 0.1
280+
eps = default_eps
272281

273-
r = np.random.rand()
282+
r: int = np.random.rand()
274283

275284
if r < eps:
276285
return np.random.choice(
@@ -301,7 +310,7 @@ def softmax(self, params: Optional[Dict] = None) -> int:
301310
try:
302311
float(tau)
303312
except ValueError:
304-
"slots: softmax: Setting tau to default"
313+
print("slots: softmax: Setting tau to default")
305314
tau = default_tau
306315
else:
307316
tau = default_tau
@@ -582,7 +591,7 @@ def pull(self, i: int) -> Optional[int]:
582591
if not self.hist_payouts[i]:
583592
return None
584593
else:
585-
_p = self.hist_payouts[i][0]
594+
_p: int = self.hist_payouts[i][0]
586595
self.hist_payouts[i] = self.hist_payouts[i][1:]
587596
return _p
588597
else:

0 commit comments

Comments
 (0)