
Commit 7d17356

Initial type hints
1 parent d23e392 commit 7d17356

3 files changed: 94 additions, 59 deletions


README.md

Lines changed: 1 addition & 0 deletions
@@ -130,3 +130,4 @@ The current development environment uses:
 - pytest >= 5.3 (5.3.2)
 - black >= 19.1 (19.10b0)
+- mypy = 0.761

mypy.ini

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+[mypy]
+[mypy-numpy]
+ignore_missing_imports = True
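As a point of reference (not part of this commit): numpy shipped no type stubs that mypy 0.761 could find, so without the [mypy-numpy] override an annotated module that does import numpy would be reported for the import itself. With ignore_missing_imports = True that report is silenced while the annotations are still checked. A minimal illustrative snippet in the same style as slots/slots.py:

# Illustrative only; mirrors the annotation style used in slots/slots.py.
# With the [mypy-numpy] override above, mypy 0.761 no longer reports the
# missing numpy stubs for this import, but still checks the hints below.
import numpy as np


def win_rate(wins: np.ndarray, pulls: np.ndarray) -> np.ndarray:
    # Same smoothed win-rate estimate used throughout slots.py.
    return wins / (pulls + 0.1)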

slots/slots.py

Lines changed: 90 additions & 59 deletions
@@ -18,6 +18,7 @@
     mab.online_trial(bandit=1, payout=0)
 """

+from typing import Optional, List, Dict, Any, Union

 import numpy as np

@@ -29,12 +30,12 @@ class MAB(object):

     def __init__(
         self,
-        num_bandits=3,
-        probs=None,
-        hist_payouts=None,
-        live=False,
-        stop_criterion={"criterion": "regret", "value": 0.1},
-    ):
+        num_bandits: Optional[int] = 3,
+        probs: Optional[np.ndarray] = None,
+        hist_payouts: Optional[List[np.ndarray]] = None,
+        live: bool = False,
+        stop_criterion: Optional[Dict] = {"criterion": "regret", "value": 0.1},
+    ) -> None:
         """
         Parameters
         ----------
@@ -51,7 +52,7 @@ def __init__(
             Stopping criterion (str) and threshold value (float).
         """

-        self.choices = []
+        self.choices: List[int] = []

         if not probs:
             if not hist_payouts:
@@ -87,7 +88,7 @@ def __init__(
                 print(
                     "slots: Since historical payout data has been supplied, probabilities will be ignored."
                 )
-                if len(probs) == len(payouts):
+                if len(probs) == len(hist_payouts):
                     self.bandits = Bandits(
                         hist_payouts=hist_payouts,
                         live=False,
@@ -104,18 +105,32 @@ def __init__(
                 probs=probs, payouts=np.zeros(num_bandits), live=False
             )

-        self.wins = np.zeros(num_bandits)
-        self.pulls = np.zeros(num_bandits)
+        self.wins: np.ndarray = np.zeros(num_bandits)
+        self.pulls: np.ndarray = np.zeros(num_bandits)

         # Set the stopping criteria
-        self.criteria = {"regret": self.regret_met}
-        self.criterion = stop_criterion.get("criterion", "regret")
-        self.stop_value = stop_criterion.get("value", 0.1)
+        self.criteria: Dict = {"regret": self.regret_met}
+        if not stop_criterion:
+            self.criterion: str = "regret"
+            self.stop_value: float = 0.1
+        else:
+            self.criterion = stop_criterion.get("criterion", "regret")
+            self.stop_value = stop_criterion.get("value", 0.1)

         # Bandit selection strategies
-        self.strategies = ["eps_greedy", "softmax", "ucb", "bayesian"]
+        self.strategies: List[str] = [
+            "eps_greedy",
+            "softmax",
+            "ucb",
+            "bayesian",
+        ]

-    def run(self, trials=100, strategy="eps_greedy", parameters=None):
+    def run(
+        self,
+        trials: int = 100,
+        strategy: str = "eps_greedy",
+        parameters: Optional[Dict] = None,
+    ) -> None:
         """
         Run MAB test with T trials.
@@ -154,31 +169,33 @@ def run(self, trials=100, strategy="eps_greedy", parameters=None):
         for n in range(trials):
             self._run(strategy, parameters)

-    def _run(self, strategy, parameters=None):
+    def _run(self, strategy: str, parameters: Optional[Dict] = None) -> None:
         """
         Run single trial of MAB strategy.

         Parameters
         ----------
-        strategy : function
+        strategy : str
         parameters : dict

         Returns
         -------
         None
         """

-        choice = self.run_strategy(strategy, parameters)
+        choice: int = self.run_strategy(strategy, parameters)
         self.choices.append(choice)
-        payout = self.bandits.pull(choice)
+        payout: Optional[int] = self.bandits.pull(choice)
         if payout is None:
             print("Trials exhausted. No more values for bandit", choice)
             return None
         else:
             self.wins[choice] += payout
             self.pulls[choice] += 1

-    def run_strategy(self, strategy, parameters):
+    def run_strategy(
+        self, strategy: str, parameters: Optional[Dict] = None
+    ) -> int:
         """
         Run the selected strategy and return bandit choice.
@@ -198,7 +215,7 @@ def run_strategy(self, strategy, parameters):
         return self.__getattribute__(strategy)(params=parameters)

     # ###### ----------- MAB strategies ---------------------------------------####
-    def max_mean(self):
+    def max_mean(self) -> int:
         """
         Pick the bandit with the current best observed proportion of winning.
@@ -210,7 +227,7 @@ def max_mean(self):

         return np.argmax(self.wins / (self.pulls + 0.1))

-    def bayesian(self, params=None):
+    def bayesian(self, params: Any = None) -> int:
         """
         Run the Bayesian Bandit algorithm which utilizes a beta distribution
         for exploration and exploitation.
@@ -233,7 +250,7 @@ def bayesian(self, params=None):

         return np.array(p_success_arms).argmax()

-    def eps_greedy(self, params):
+    def eps_greedy(self, params: Optional[Dict] = None) -> int:
         """
         Run the epsilon-greedy strategy and update self.max_mean()
@@ -262,7 +279,7 @@ def eps_greedy(self, params):
         else:
             return self.max_mean()

-    def softmax(self, params):
+    def softmax(self, params: Optional[Dict] = None) -> int:
         """
         Run the softmax selection strategy.
@@ -277,10 +294,10 @@ def softmax(self, params):
             Index of chosen bandit
         """

-        default_tau = 0.1
+        default_tau: float = 0.1

         if params and type(params) == dict:
-            tau = params.get("tau")
+            tau: float = params.get("tau", default_tau)
             try:
                 float(tau)
             except ValueError:
@@ -293,19 +310,19 @@ def softmax(self, params):
         if True in (self.pulls < 3):
             return np.random.choice(range(len(self.pulls)))
         else:
-            payouts = self.wins / (self.pulls + 0.1)
-            norm = sum(np.exp(payouts / tau))
+            payouts: np.ndarray = self.wins / (self.pulls + 0.1)
+            norm: float = sum(np.exp(payouts / tau))

-            ps = np.exp(payouts / tau) / norm
+            ps: np.ndarray = np.exp(payouts / tau) / norm

             # Randomly choose index based on CMF
-            cmf = [sum(ps[: i + 1]) for i in range(len(ps))]
+            cmf: List[int] = [sum(ps[: i + 1]) for i in range(len(ps))]

-            rand = np.random.rand()
+            rand: float = np.random.rand()

-            found = False
-            found_i = None
-            i = 0
+            found: bool = False
+            found_i: int = 0
+            i: int = 0
             while not found:
                 if rand < cmf[i]:
                     found_i = i
@@ -315,7 +332,7 @@ def softmax(self, params):

             return found_i

-    def ucb(self, params=None):
+    def ucb(self, params: Optional[Dict] = None) -> int:
         """
         Run the upper confidence bound MAB selection strategy.
@@ -340,15 +357,17 @@ def ucb(self, params=None):
         if True in (self.pulls < 3):
             return np.random.choice(range(len(self.pulls)))
         else:
-            n_tot = sum(self.pulls)
-            payouts = self.wins / (self.pulls + 0.1)
-            ubcs = payouts + np.sqrt(2 * np.log(n_tot) / self.pulls)
+            n_tot: int = sum(self.pulls)
+            payouts: np.ndarray = self.wins / (self.pulls + 0.1)
+            ubcs: np.ndarray = payouts + np.sqrt(
+                2 * np.log(n_tot) / self.pulls
+            )

             return np.argmax(ubcs)

     # ###------------------------------------------------------------------####

-    def best(self):
+    def best(self) -> Optional[int]:
         """
         Return current 'best' choice of bandit.
@@ -364,7 +383,7 @@ def best(self):
         else:
             return np.argmax(self.wins / (self.pulls + 0.1))

-    def est_probs(self):
+    def est_probs(self) -> Optional[np.ndarray]:
         """
         Calculate current estimate of average payout for each bandit.
@@ -379,7 +398,7 @@ def est_probs(self):
         else:
             return self.wins / (self.pulls + 0.1)

-    def regret(self):
+    def regret(self) -> float:
         """
         Calculate expected regret, where expected regret is
         maximum optimal reward - sum of collected rewards, i.e.
@@ -396,7 +415,7 @@ def regret(self):
             - sum(self.wins)
         ) / sum(self.pulls)

-    def crit_met(self):
+    def crit_met(self) -> bool:
         """
         Determine if stopping criterion has been met.
@@ -410,7 +429,7 @@ def crit_met(self):
         else:
             return self.criteria[self.criterion](self.stop_value)

-    def regret_met(self, threshold=None):
+    def regret_met(self, threshold: Optional[float] = None) -> bool:
         """
         Determine if regret criterion has been met.
@@ -432,8 +451,12 @@ def regret_met(self, threshold=None):

     # ## ------------ Online bandit testing ------------------------------ ####
     def online_trial(
-        self, bandit=None, payout=None, strategy="eps_greedy", parameters=None
-    ):
+        self,
+        bandit: Optional[int] = None,
+        payout: Optional[int] = None,
+        strategy: str = "eps_greedy",
+        parameters: Optional[Dict] = None,
+    ) -> Dict:
         """
         Update the bandits with the results of the previous live, online trial.
         Next run the selection algorithm. If the stopping criteria is
@@ -444,7 +467,7 @@ def online_trial(
         ----------
         bandit : int
             Bandit index of most recent trial
-        payout : float
+        payout : int
             Payout value of most recent trial
         strategy : string
             Name of update strategy
@@ -477,7 +500,7 @@ def online_trial(
                 "best": self.best(),
             }

-    def update(self, bandit, payout):
+    def update(self, bandit, payout) -> None:
         """
         Update bandit trials and payouts for given bandit.
@@ -503,7 +526,13 @@ class Bandits:
     Bandit class.
     """

-    def __init__(self, payouts, probs=None, hist_payouts=None, live=False):
+    def __init__(
+        self,
+        payouts: np.ndarray,
+        probs: Optional[np.ndarray] = None,
+        hist_payouts: Optional[List[np.ndarray]] = None,
+        live: bool = False,
+    ):
         """
         Instantiate Bandit class, determining
             - Probabilities of bandit payouts
@@ -521,16 +550,16 @@ def __init__(self, payouts, probs=None, hist_payouts=None, live=False):
         """

         if not live:
-            self.probs = probs
-            self.payouts = payouts
-            self.hist_payouts = hist_payouts
-            self.live = False
+            self.probs: Optional[np.ndarray] = probs
+            self.payouts: np.ndarray = payouts
+            self.hist_payouts: Optional[List[np.ndarray]] = hist_payouts
+            self.live: bool = False
         else:
             self.live = True
             self.probs = None
             self.payouts = payouts

-    def pull(self, i):
+    def pull(self, i: int) -> Optional[int]:
         """
         Return the payout from a single pull of the bandit i's arm.
@@ -541,7 +570,7 @@ def pull(self, i):

         Returns
         -------
-        float or None
+        int or None
         """

         if self.live:
@@ -550,17 +579,19 @@ def pull(self, i):
             else:
                 return None
         elif self.hist_payouts:
-            if not hist[i]:
+            if not self.hist_payouts[i]:
                 return None
             else:
-                _p = hist[i][0]
-                hist[i] = hist[i][1:]
+                _p = self.hist_payouts[i][0]
+                self.hist_payouts[i] = self.hist_payouts[i][1:]
                 return _p
         else:
-            if np.random.rand() < self.probs[i]:
+            if self.probs is None:
+                return None
+            elif np.random.rand() < self.probs[i]:
                 return 1
             else:
                 return 0

-    def info(self):
+    def info(self) -> None:
         pass
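For orientation, a minimal usage sketch of the API as typed above. This is not part of the commit: the import path follows the slots/slots.py file shown here, and the payout probabilities are made-up illustration values.

# Illustrative sketch only; class and method names are taken from the diff above.
from slots.slots import MAB  # the package may also re-export this as slots.MAB

# Offline simulation against three bandits with made-up payout probabilities.
mab = MAB(
    num_bandits=3,
    probs=[0.1, 0.3, 0.8],  # hinted as Optional[np.ndarray]; a plain list also satisfies the runtime checks
    live=False,
    stop_criterion=None,  # exercises the new fallback to {"criterion": "regret", "value": 0.1}
)

mab.run(trials=500, strategy="eps_greedy")

print(mab.best())       # index of the current best bandit (None before any pulls)
print(mab.est_probs())  # per-bandit payout estimates: wins / (pulls + 0.1)
print(mab.regret())     # expected regret, used by the regret stopping criterion

# Live/online use, as in the module docstring shown at the top of the diff:
# mab.online_trial(bandit=1, payout=0)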
