1818 mab.online_trial(bandit=1, payout=0)
1919"""
2020
21- from typing import Optional , List , Dict , Any , Union
21+ from typing import Optional , List , Dict , Any , Union , Callable
2222
2323import numpy as np
2424
@@ -33,7 +33,7 @@ def __init__(
3333 num_bandits : Optional [int ] = 3 ,
3434 probs : Optional [np .ndarray ] = None ,
3535 hist_payouts : Optional [List [np .ndarray ]] = None ,
36- live : bool = False ,
36+ live : Optional [ bool ] = False ,
3737 stop_criterion : Optional [Dict ] = {"criterion" : "regret" , "value" : 0.1 },
3838 ) -> None :
3939 """
@@ -59,7 +59,7 @@ def __init__(
5959 if live :
6060 # Live trial scenario, where nothing is known except the
6161 # number of bandits
62- self .bandits = Bandits (
62+ self .bandits : Bandits = Bandits (
6363 live = True , payouts = np .zeros (num_bandits )
6464 )
6565 else :
@@ -109,7 +109,9 @@ def __init__(
109109 self .pulls : np .ndarray = np .zeros (num_bandits )
110110
111111 # Set the stopping criteria
112- self .criteria : Dict = {"regret" : self .regret_met }
112+ self .criteria : Dict [str , Callable [[Optional [float ]], bool ]] = {
113+ "regret" : self .regret_met
114+ }
113115 if not stop_criterion :
114116 self .criterion : str = "regret"
115117 self .stop_value : float = 0.1
@@ -243,14 +245,14 @@ def bayesian(self, params: Any = None) -> int:
243245 int
244246 Index of chosen bandit
245247 """
246- p_success_arms = [
248+ p_success_arms : List [ float ] = [
247249 np .random .beta (self .wins [i ] + 1 , self .pulls [i ] - self .wins [i ] + 1 )
248250 for i in range (len (self .wins ))
249251 ]
250252
251253 return np .array (p_success_arms ).argmax ()
252254
253- def eps_greedy (self , params : Optional [Dict ] = None ) -> int :
255+ def eps_greedy (self , params : Optional [Dict [ str , float ] ] = None ) -> int :
254256 """
255257 Run the epsilon-greedy strategy and update self.max_mean()
256258
@@ -265,12 +267,19 @@ def eps_greedy(self, params: Optional[Dict] = None) -> int:
265267 Index of chosen bandit
266268 """
267269
270+ default_eps : float = 0.1
271+
268272 if params and type (params ) == dict :
269- eps = params .get ("epsilon" )
273+ eps : float = params .get ("epsilon" , default_eps )
274+ try :
275+ float (eps )
276+ except ValueError :
277+ print ("slots: eps_greedy: Setting eps to default" )
278+ eps = default_eps
270279 else :
271- eps = 0.1
280+ eps = default_eps
272281
273- r = np .random .rand ()
282+ r : int = np .random .rand ()
274283
275284 if r < eps :
276285 return np .random .choice (
@@ -301,7 +310,7 @@ def softmax(self, params: Optional[Dict] = None) -> int:
301310 try :
302311 float (tau )
303312 except ValueError :
304- "slots: softmax: Setting tau to default"
313+ print ( "slots: softmax: Setting tau to default" )
305314 tau = default_tau
306315 else :
307316 tau = default_tau
@@ -582,7 +591,7 @@ def pull(self, i: int) -> Optional[int]:
582591 if not self .hist_payouts [i ]:
583592 return None
584593 else :
585- _p = self .hist_payouts [i ][0 ]
594+ _p : int = self .hist_payouts [i ][0 ]
586595 self .hist_payouts [i ] = self .hist_payouts [i ][1 :]
587596 return _p
588597 else :
0 commit comments