     mab.online_trial(bandit=1, payout=0)
 """
 
+from typing import Optional, List, Dict, Any, Union
 
 import numpy as np
 
@@ -29,12 +30,12 @@ class MAB(object):
 
     def __init__(
         self,
-        num_bandits=3,
-        probs=None,
-        hist_payouts=None,
-        live=False,
-        stop_criterion={"criterion": "regret", "value": 0.1},
-    ):
+        num_bandits: Optional[int] = 3,
+        probs: Optional[np.ndarray] = None,
+        hist_payouts: Optional[List[np.ndarray]] = None,
+        live: bool = False,
+        stop_criterion: Optional[Dict] = {"criterion": "regret", "value": 0.1},
+    ) -> None:
         """
         Parameters
         ----------
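As a quick orientation for the annotated signature above, a minimal construction sketch (values illustrative; `slots.MAB` is the entry point used in the module docstring):

```python
# Minimal sketch: a three-armed offline test with default settings.
import slots

mab = slots.MAB(num_bandits=3, live=False)
```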
@@ -51,7 +52,7 @@ def __init__(
             Stopping criterion (str) and threshold value (float).
         """
 
-        self.choices = []
+        self.choices: List[int] = []
 
         if not probs:
             if not hist_payouts:
@@ -87,7 +88,7 @@ def __init__(
                 print(
                     "slots: Since historical payout data has been supplied, probabilities will be ignored."
                 )
-                if len(probs) == len(payouts):
+                if len(probs) == len(hist_payouts):
                     self.bandits = Bandits(
                         hist_payouts=hist_payouts,
                         live=False,
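Besides the annotations, the `-`/`+` pair above fixes a latent `NameError`: the old branch compared `len(probs)` against `payouts`, a name that does not exist in this scope. For reference, `hist_payouts` is one array of binary payouts per bandit, matching the new `List[np.ndarray]` annotation; a hypothetical example:

```python
import numpy as np

# One array of recorded 0/1 payouts per bandit (illustrative data).
hist_payouts = [
    np.array([0, 1, 0, 0]),  # bandit 0
    np.array([1, 1, 0, 1]),  # bandit 1
    np.array([0, 0, 0, 1]),  # bandit 2
]
```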
@@ -104,18 +105,32 @@ def __init__(
                 probs=probs, payouts=np.zeros(num_bandits), live=False
             )
 
-        self.wins = np.zeros(num_bandits)
-        self.pulls = np.zeros(num_bandits)
+        self.wins: np.ndarray = np.zeros(num_bandits)
+        self.pulls: np.ndarray = np.zeros(num_bandits)
 
         # Set the stopping criteria
-        self.criteria = {"regret": self.regret_met}
-        self.criterion = stop_criterion.get("criterion", "regret")
-        self.stop_value = stop_criterion.get("value", 0.1)
+        self.criteria: Dict = {"regret": self.regret_met}
+        if not stop_criterion:
+            self.criterion: str = "regret"
+            self.stop_value: float = 0.1
+        else:
+            self.criterion = stop_criterion.get("criterion", "regret")
+            self.stop_value = stop_criterion.get("value", 0.1)
 
         # Bandit selection strategies
-        self.strategies = ["eps_greedy", "softmax", "ucb", "bayesian"]
+        self.strategies: List[str] = [
+            "eps_greedy",
+            "softmax",
+            "ucb",
+            "bayesian",
+        ]
 
-    def run(self, trials=100, strategy="eps_greedy", parameters=None):
+    def run(
+        self,
+        trials: int = 100,
+        strategy: str = "eps_greedy",
+        parameters: Optional[Dict] = None,
+    ) -> None:
         """
         Run MAB test with T trials.
 
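The new `if not stop_criterion:` branch also makes `stop_criterion=None` safe, where the old code would have called `.get` on `None`. Both call styles, sketched with illustrative thresholds:

```python
import slots

# Explicit criterion and threshold:
mab = slots.MAB(num_bandits=3, stop_criterion={"criterion": "regret", "value": 0.05})

# None now falls back to the same defaults ("regret", 0.1):
mab = slots.MAB(num_bandits=3, stop_criterion=None)
```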
@@ -154,31 +169,33 @@ def run(self, trials=100, strategy="eps_greedy", parameters=None):
         for n in range(trials):
             self._run(strategy, parameters)
 
-    def _run(self, strategy, parameters=None):
+    def _run(self, strategy: str, parameters: Optional[Dict] = None) -> None:
         """
         Run single trial of MAB strategy.
 
         Parameters
         ----------
-        strategy : function
+        strategy : str
         parameters : dict
 
         Returns
         -------
         None
         """
 
-        choice = self.run_strategy(strategy, parameters)
+        choice: int = self.run_strategy(strategy, parameters)
         self.choices.append(choice)
-        payout = self.bandits.pull(choice)
+        payout: Optional[int] = self.bandits.pull(choice)
         if payout is None:
             print("Trials exhausted. No more values for bandit", choice)
             return None
         else:
             self.wins[choice] += payout
             self.pulls[choice] += 1
 
-    def run_strategy(self, strategy, parameters):
+    def run_strategy(
+        self, strategy: str, parameters: Optional[Dict] = None
+    ) -> int:
         """
         Run the selected strategy and return bandit choice.
 
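`run` repeats `_run`, which asks `run_strategy` for an arm index, pulls it, and tallies `wins` and `pulls`. A typical offline experiment under these signatures (parameter values illustrative):

```python
import slots

mab = slots.MAB(num_bandits=3)
mab.run(trials=1000, strategy="softmax", parameters={"tau": 0.1})
print(mab.best(), mab.est_probs())
```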
@@ -198,7 +215,7 @@ def run_strategy(self, strategy, parameters):
         return self.__getattribute__(strategy)(params=parameters)
 
     # ###### ----------- MAB strategies ---------------------------------------####
-    def max_mean(self):
+    def max_mean(self) -> int:
         """
         Pick the bandit with the current best observed proportion of winning.
 
@@ -210,7 +227,7 @@ def max_mean(self):
 
         return np.argmax(self.wins / (self.pulls + 0.1))
 
-    def bayesian(self, params=None):
+    def bayesian(self, params: Any = None) -> int:
         """
         Run the Bayesian Bandit algorithm which utilizes a beta distribution
         for exploration and exploitation.
@@ -233,7 +250,7 @@ def bayesian(self, params=None):
 
         return np.array(p_success_arms).argmax()
 
-    def eps_greedy(self, params):
+    def eps_greedy(self, params: Optional[Dict] = None) -> int:
         """
         Run the epsilon-greedy strategy and update self.max_mean()
 
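The `bayesian` body is elided from this hunk; it builds `p_success_arms` by sampling each arm's beta posterior and taking the argmax, i.e. Thompson sampling. A minimal standalone sketch of that idea, not the verbatim slots implementation:

```python
import numpy as np

def thompson_choice(wins: np.ndarray, pulls: np.ndarray) -> int:
    # Draw one sample per arm from Beta(1 + wins, 1 + losses); pick the largest.
    samples = np.random.beta(1 + wins, 1 + (pulls - wins))
    return int(np.argmax(samples))
```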
@@ -262,7 +279,7 @@ def eps_greedy(self, params):
         else:
             return self.max_mean()
 
-    def softmax(self, params):
+    def softmax(self, params: Optional[Dict] = None) -> int:
         """
         Run the softmax selection strategy.
 
@@ -277,10 +294,10 @@ def softmax(self, params):
             Index of chosen bandit
         """
 
-        default_tau = 0.1
+        default_tau: float = 0.1
 
         if params and type(params) == dict:
-            tau = params.get("tau")
+            tau: float = params.get("tau", default_tau)
             try:
                 float(tau)
             except ValueError:
@@ -293,19 +310,19 @@ def softmax(self, params):
         if True in (self.pulls < 3):
             return np.random.choice(range(len(self.pulls)))
         else:
-            payouts = self.wins / (self.pulls + 0.1)
-            norm = sum(np.exp(payouts / tau))
+            payouts: np.ndarray = self.wins / (self.pulls + 0.1)
+            norm: float = sum(np.exp(payouts / tau))
 
-            ps = np.exp(payouts / tau) / norm
+            ps: np.ndarray = np.exp(payouts / tau) / norm
 
             # Randomly choose index based on CMF
-            cmf = [sum(ps[: i + 1]) for i in range(len(ps))]
+            cmf: List[float] = [sum(ps[: i + 1]) for i in range(len(ps))]
 
-            rand = np.random.rand()
+            rand: float = np.random.rand()
 
-            found = False
-            found_i = None
-            i = 0
+            found: bool = False
+            found_i: int = 0
+            i: int = 0
             while not found:
                 if rand < cmf[i]:
                     found_i = i
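The `while` loop in this hunk walks the cumulative distribution `cmf` until the uniform draw `rand` falls into a bucket, i.e. it samples an index with probability `ps[i]`. For clarity only, the same step as a single call:

```python
import numpy as np

ps = np.array([0.2, 0.5, 0.3])  # illustrative softmax probabilities
choice = np.random.choice(len(ps), p=ps)  # equivalent to the CMF walk
```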
@@ -315,7 +332,7 @@ def softmax(self, params):
 
         return found_i
 
-    def ucb(self, params=None):
+    def ucb(self, params: Optional[Dict] = None) -> int:
         """
         Run the upper confidence bound MAB selection strategy.
 
@@ -340,15 +357,17 @@ def ucb(self, params=None):
         if True in (self.pulls < 3):
             return np.random.choice(range(len(self.pulls)))
         else:
-            n_tot = sum(self.pulls)
-            payouts = self.wins / (self.pulls + 0.1)
-            ubcs = payouts + np.sqrt(2 * np.log(n_tot) / self.pulls)
+            n_tot: float = sum(self.pulls)
+            payouts: np.ndarray = self.wins / (self.pulls + 0.1)
+            ubcs: np.ndarray = payouts + np.sqrt(
+                2 * np.log(n_tot) / self.pulls
+            )
 
             return np.argmax(ubcs)
 
     # ###------------------------------------------------------------------####
 
-    def best(self):
+    def best(self) -> Optional[int]:
         """
         Return current 'best' choice of bandit.
 
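`ubcs` is the standard UCB1 score: the arm's estimated payout plus an exploration bonus `sqrt(2 * ln(n_tot) / pulls)` that shrinks as the arm accumulates pulls. A toy calculation with made-up counts:

```python
import numpy as np

wins = np.array([10.0, 4.0, 1.0])
pulls = np.array([40.0, 10.0, 5.0])
n_tot = pulls.sum()

ucb = wins / (pulls + 0.1) + np.sqrt(2 * np.log(n_tot) / pulls)
# Less-pulled arms receive a larger bonus, steering exploration toward them.
print(np.argmax(ucb))
```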
@@ -364,7 +383,7 @@ def best(self):
         else:
             return np.argmax(self.wins / (self.pulls + 0.1))
 
-    def est_probs(self):
+    def est_probs(self) -> Optional[np.ndarray]:
         """
         Calculate current estimate of average payout for each bandit.
 
@@ -379,7 +398,7 @@ def est_probs(self):
         else:
             return self.wins / (self.pulls + 0.1)
 
-    def regret(self):
+    def regret(self) -> float:
         """
         Calculate expected regret, where expected regret is
         maximum optimal reward - sum of collected rewards, i.e.
@@ -396,7 +415,7 @@ def regret(self):
             - sum(self.wins)
         ) / sum(self.pulls)
 
-    def crit_met(self):
+    def crit_met(self) -> bool:
         """
         Determine if stopping criterion has been met.
 
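Concretely, the regret computed above is `(max estimated payout * total pulls - total wins) / total pulls`. For example, if the best arm's estimated payout is 0.5 after 100 total pulls and 30 total wins, regret = (0.5 * 100 - 30) / 100 = 0.2.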
@@ -410,7 +429,7 @@ def crit_met(self):
         else:
             return self.criteria[self.criterion](self.stop_value)
 
-    def regret_met(self, threshold=None):
+    def regret_met(self, threshold: Optional[float] = None) -> bool:
         """
         Determine if regret criterion has been met.
 
@@ -432,8 +451,12 @@ def regret_met(self, threshold=None):
 
     # ## ------------ Online bandit testing ------------------------------ ####
     def online_trial(
-        self, bandit=None, payout=None, strategy="eps_greedy", parameters=None
-    ):
+        self,
+        bandit: Optional[int] = None,
+        payout: Optional[int] = None,
+        strategy: str = "eps_greedy",
+        parameters: Optional[Dict] = None,
+    ) -> Dict:
         """
         Update the bandits with the results of the previous live, online trial.
         Next, run the selection algorithm. If the stopping criterion is
@@ -444,7 +467,7 @@ def online_trial(
         ----------
         bandit : int
             Bandit index of most recent trial
-        payout : float
+        payout : int
             Payout value of most recent trial
         strategy : string
             Name of update strategy
@@ -477,7 +500,7 @@ def online_trial(
477500 "best" : self .best (),
478501 }
479502
480- def update (self , bandit , payout ):
503+ def update (self , bandit , payout ) -> None :
481504 """
482505 Update bandit trials and payouts for given bandit.
483506
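In live mode, `online_trial` feeds the observed payout of the previous trial back into the model and returns a dict that includes the current `"best"` arm. A sketch of a single round, reusing the call from the module docstring (values illustrative):

```python
import slots

mab = slots.MAB(num_bandits=3, live=True)

# Report the outcome of the trial just served, then read back
# the current best-arm estimate from the returned dict.
result = mab.online_trial(bandit=1, payout=0)
print(result["best"])
```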
@@ -503,7 +526,13 @@ class Bandits:
     Bandit class.
     """
 
-    def __init__(self, payouts, probs=None, hist_payouts=None, live=False):
+    def __init__(
+        self,
+        payouts: np.ndarray,
+        probs: Optional[np.ndarray] = None,
+        hist_payouts: Optional[List[np.ndarray]] = None,
+        live: bool = False,
+    ):
         """
         Instantiate Bandit class, determining
         - Probabilities of bandit payouts
@@ -521,16 +550,16 @@ def __init__(self, payouts, probs=None, hist_payouts=None, live=False):
521550 """
522551
523552 if not live :
524- self .probs = probs
525- self .payouts = payouts
526- self .hist_payouts = hist_payouts
527- self .live = False
553+ self .probs : Optional [ np . ndarray ] = probs
554+ self .payouts : np . ndarray = payouts
555+ self .hist_payouts : Optional [ List [ np . ndarray ]] = hist_payouts
556+ self .live : bool = False
528557 else :
529558 self .live = True
530559 self .probs = None
531560 self .payouts = payouts
532561
533- def pull (self , i ) :
562+ def pull (self , i : int ) -> Optional [ int ] :
534563 """
535564 Return the payout from a single pull of the bandit i's arm.
536565
@@ -541,7 +570,7 @@ def pull(self, i):
 
         Returns
         -------
-        float or None
+        int or None
         """
 
         if self.live:
@@ -550,17 +579,19 @@ def pull(self, i):
             else:
                 return None
         elif self.hist_payouts:
-            if not hist[i]:
+            if not self.hist_payouts[i]:
                 return None
             else:
-                _p = hist[i][0]
-                hist[i] = hist[i][1:]
+                _p = self.hist_payouts[i][0]
+                self.hist_payouts[i] = self.hist_payouts[i][1:]
                 return _p
         else:
-            if np.random.rand() < self.probs[i]:
+            if self.probs is None:
+                return None
+            elif np.random.rand() < self.probs[i]:
                 return 1
             else:
                 return 0
 
-    def info(self):
+    def info(self) -> None:
         pass
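The `pull` changes fix the historical-replay branch, which previously referenced an undefined `hist` name: payouts are now consumed one at a time from `self.hist_payouts[i]`, with `None` signaling that bandit i's history is exhausted, and `None` likewise returned when no `probs` are available for simulation. A small behavioral sketch (plain lists used for the history so the `not self.hist_payouts[i]` emptiness check stays unambiguous):

```python
import numpy as np

b = Bandits(
    payouts=np.zeros(2),
    hist_payouts=[[1, 0], [0]],  # illustrative recorded payouts
    live=False,
)
print(b.pull(0))  # 1
print(b.pull(0))  # 0
print(b.pull(0))  # None: bandit 0's history is exhausted
```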