@@ -69,7 +69,7 @@ def __init__(self, num_bandits=None, probs=None, payouts=None, live=False,
6969 num_bandits = len (payouts )
7070 else :
7171 self .bandits = Bandits (probs = probs ,
72- payouts = np .ones (len (payouts )))
72+ payouts = np .ones (len (probs )))
7373 num_bandits = len (probs )
7474
7575 self .wins = np .zeros (num_bandits )
@@ -85,6 +85,9 @@ def __init__(self, num_bandits=None, probs=None, payouts=None, live=False,
8585 self .criterion = 'regret'
8686 self .stop_value = 0.1
8787
88+ # Bandit selection strategies
89+ self .strategies = ['eps_greedy' , 'softmax' , 'ucb' ]
90+
8891 def run (self , trials = 100 , strategy = None , parameters = None ):
8992 '''
9093 Run MAB test with T trials.
@@ -94,25 +97,24 @@ def run(self, trials=100, strategy=None, parameters=None):
9497 strategy (string) - name of selected strategy.
9598 parameters (dict) - parameters for selected strategy.
9699
97- Currently on epsilon greedy is implemented.
100+ Available strategies:
101+ - Epsilon-greedy ("eps_greedy")
102+ - Softmax ("softmax")
103+ - Upper confidence bound ("ucb")
98104 '''
99105
100- strategies = {'eps_greedy' : self .eps_greedy ,
101- 'softmax' : self .softmax ,
102- 'ucb' : self .ucb }
103-
104106 if trials < 1 :
105107 raise Exception ('MAB.run: Number of trials cannot be less than 1!' )
106108 if not strategy :
107109 strategy = 'eps_greedy'
108110 else :
109- if strategy not in strategies :
111+ if strategy not in self . strategies :
110112 raise Exception ('MAB,run: Strategy name invalid. Choose from:'
111- ' {}' .format (', ' .join (strategies )))
113+ ' {}' .format (', ' .join (self . strategies )))
112114
113115 # Run strategy
114116 for n in range (trials ):
115- self ._run (strategies [ strategy ] , parameters )
117+ self ._run (strategy , parameters )
116118
117119 def _run (self , strategy , parameters = None ):
118120 '''
@@ -126,7 +128,7 @@ def _run(self, strategy, parameters=None):
126128 None
127129 '''
128130
129- choice = strategy ( params = parameters )
131+ choice = self . run_strategy ( strategy , parameters )
130132 self .choices .append (choice )
131133 payout = self .bandits .pull (choice )
132134 if payout is None :
@@ -136,6 +138,19 @@ def _run(self, strategy, parameters=None):
136138 self .wins [choice ] += payout
137139 self .pulls [choice ] += 1
138140
141+ def run_strategy (self , strategy , parameters ):
142+ '''
143+ Run the selected strategy and return bandit choice.
144+
145+ Input:
146+ strategy - string of strategy name
147+ parameters - dict of strategy function parameters
148+
149+ Output:
150+ integer. Call strategy function, which returns bandit arm choice.
151+ '''
152+
153+ return self .__getattribute__ (strategy )(params = parameters )
139154
140155# ###### ----------- MAB strategies ---------------------------------------####
141156 def max_mean (self ):
@@ -145,6 +160,7 @@ def max_mean(self):
145160 Input: self
146161 Output: int (index of chosen bandit)
147162 """
163+
148164 return np .argmax (self .wins / (self .pulls + 0.1 ))
149165
150166 def eps_greedy (self , params ):
@@ -161,6 +177,7 @@ def eps_greedy(self, params):
161177 eps = 0.1
162178
163179 r = np .random .rand ()
180+
164181 if r < eps :
165182 return np .random .choice (list (set (range (len (self .wins ))) -
166183 {self .max_mean ()}))
@@ -175,7 +192,7 @@ def softmax(self, params):
175192 Output: int (index of chosen bandit)
176193 '''
177194
178- default_tau = 1.0
195+ default_tau = 0.1
179196
180197 if params and type (params ) == dict :
181198 tau = params .get ('tau' )
@@ -189,7 +206,7 @@ def softmax(self, params):
189206
190207 # Handle cold start. Not all bandits tested yet.
191208 if True in (self .pulls < 3 ):
192- return np .random .choice (xrange (len (self .pulls )))
209+ return np .random .choice (range (len (self .pulls )))
193210 else :
194211 payouts = self .wins / (self .pulls + 0.1 )
195212 norm = sum (np .exp (payouts / tau ))
@@ -225,7 +242,7 @@ def ucb(self, params=None):
225242
226243 # Handle cold start. Not all bandits tested yet.
227244 if True in (self .pulls < 3 ):
228- return np .random .choice (xrange (len (self .pulls )))
245+ return np .random .choice (range (len (self .pulls )))
229246 else :
230247 n_tot = sum (self .pulls )
231248 payouts = self .wins / (self .pulls + 0.1 )
@@ -301,16 +318,22 @@ def regret_met(self, threshold=None):
301318 return False
302319
303320 # ## ------------ Online bandit testing ------------------------------ ####
304- def online_trial (self , bandit = None , payout = None ):
321+ def online_trial (self , bandit = None , payout = None , strategy = 'eps_greedy' ,
322+ parameters = None ):
305323 '''
306324 Update the bandits with the results of the previous live, online trial.
306324 Next run the selection algorithm. If the stopping criterion is
308326 met, return the best arm estimate. Otherwise return the next arm to
309327 try.
310328
311- Input: int (bandit to update), float (payout to update for bandit)
312- Output: dict
313- format: {'new_trial': boolean, 'choice': int, 'best': int}
329+ Input:
330+ bandit - int of bandit index
331+ payout - float of payout value
332+ strategy - string name of update strategy
333+ parameters - dict of parameters for update strategy function
334+
335+ Output:
336+ dict - format: {'new_trial': boolean, 'choice': int, 'best': int}
314337 '''
315338
316339 if bandit and payout :
@@ -323,11 +346,9 @@ def online_trial(self, bandit=None, payout=None):
323346 return {'new_trial' : False , 'choice' : self .best (),
324347 'best' : self .best ()}
325348 else :
326- # TODO: implement choice via strategy
327- print ('slots: online trial strategy not yet implemented.'
328- ' No trial run.' )
329- # choice = self.run(trial=1)
330- pass
349+ return {'new_trial' : True ,
350+ 'choice' : self .run_strategy (strategy , parameters ),
351+ 'best' : self .best ()}
331352
332353 def update (self , bandit = None , payout = None ):
333354 '''
0 commit comments