@@ -23,8 +23,8 @@ class MAB(object):
2323 Multi-armed bandit test class.
2424 '''
2525
26- def __init__ (self , num_bandits = 3 , probs = None , payouts = None , live = False ,
27- stop_criterion = {'criterion' : 'regret' , 'value' : 1.0 }):
26+ def __init__ (self , num_bandits = 3 , probs = None , payouts = None , live = True ,
27+ stop_criterion = {'criterion' : 'regret' , 'value' : 0.1 }):
2828 '''
2929 Parameters
3030 ----------
@@ -33,37 +33,42 @@ def __init__(self, num_bandits=3, probs=None, payouts=None, live=False,
3333 probs : np.array of floats
3434 payout probabilities
3535 payouts : np.array of floats
36- If `live` is True, `payouts` should be an N*T array of payout
37- amount per pull (floats) for N bandits and T trials
36+ If `live` is True, `payouts` should be None.
3837 live : bool
38+ Whether the use is for a live, online trial.
3939 stop_criterion : dict
4040 Stopping criterion (str) and threshold value (float).
4141 '''
4242
4343 self .choices = []
4444
4545 if not probs :
46- if payouts is None :
47- self .bandits = Bandits (probs = [np .random .rand () for x in
48- range (num_bandits )],
49- payouts = np .ones (num_bandits ))
50- else :
46+ if not payouts :
5147 if live :
52- self .bandits = Bandits (live = True , payouts = payouts ,
48+ self .bandits = Bandits (live = True ,
49+ payouts = np .zeros (num_bandits ),
5350 probs = None )
5451 else :
55- # Not sure why anyone would do this
5652 self .bandits = Bandits (probs = [np .random .rand () for x in
57- range (len (payouts ))],
58- payouts = payouts )
53+ range (num_bandits )],
54+ payouts = np .ones (num_bandits ),
55+ live = False )
56+ else :
57+
58+ self .bandits = Bandits (probs = [np .random .rand () for x in
59+ range (len (payouts ))],
60+ payouts = payouts ,
61+ live = False )
5962 num_bandits = len (payouts )
6063 else :
6164 if payouts :
62- self .bandits = Bandits (probs = probs , payouts = payouts )
65+ self .bandits = Bandits (probs = probs , payouts = payouts ,
66+ live = False )
6367 num_bandits = len (payouts )
6468 else :
6569 self .bandits = Bandits (probs = probs ,
66- payouts = np .ones (len (probs )))
70+ payouts = np .ones (len (probs )),
71+ live = False )
6772 num_bandits = len (probs )
6873
6974 self .wins = np .zeros (num_bandits )
@@ -361,7 +366,10 @@ def crit_met(self):
361366 bool
362367 '''
363368
364- return self .criteria [self .criterion ](self .stop_value )
369+ if True in (self .pulls < 3 ):
370+ return False
371+ else :
372+ return self .criteria [self .criterion ](self .stop_value )
365373
366374 def regret_met (self , threshold = None ):
367375 '''
@@ -409,7 +417,7 @@ def online_trial(self, bandit=None, payout=None, strategy='eps_greedy',
409417 Format: {'new_trial': boolean, 'choice': int, 'best': int}
410418 '''
411419
412- if bandit and payout :
420+ if bandit is not None and payout is not None :
413421 self .update (bandit = bandit , payout = payout )
414422 else :
415423 raise Exception ('slots.online_trial: bandit and/or payout value'
@@ -423,7 +431,7 @@ def online_trial(self, bandit=None, payout=None, strategy='eps_greedy',
423431 'choice' : self .run_strategy (strategy , parameters ),
424432 'best' : self .best ()}
425433
426- def update (self , bandit = None , payout = None ):
434+ def update (self , bandit , payout ):
427435 '''
428436 Update bandit trials and payouts for given bandit.
429437
@@ -438,6 +446,7 @@ def update(self, bandit=None, payout=None):
438446 None
439447 '''
440448
449+ self .choices .append (bandit )
441450 self .pulls [bandit ] += 1
442451 self .wins [bandit ] += payout
443452 self .bandits .payouts [bandit ] += payout
@@ -448,7 +457,7 @@ class Bandits():
448457 Bandit class.
449458 '''
450459
451- def __init__ (self , probs , payouts , live = False ):
460+ def __init__ (self , probs , payouts , live = True ):
452461 '''
453462 Instantiate Bandit class, determining
454463 - Probabilities of bandit payouts
@@ -460,8 +469,7 @@ def __init__(self, probs, payouts, live=False):
460469 Probabilities of bandit payouts
461470 payouts : array of floats
462471 Amount of bandit payouts. If `live` is True, `payouts` should be an
463- N*T array of floats giving the payout amount per pull for N bandits
464- and T trials.
472+ N length array of zeros.
465473 live : bool
466474 '''
467475
0 commit comments