@@ -76,20 +76,20 @@ def __init__(self, num_bandits=None, probs=None, payouts=None, live=False,
         self.pulls = np.zeros(num_bandits)
 
         # Set the stopping criteria
-        self.criteria = {'regret': self.regret_met()}
+        self.criteria = {'regret': self.regret_met}
         if stop_criterion.get('criterion') in self.criteria:
             self.criterion = stop_criterion['criterion']
             if stop_criterion.get('value'):
                 self.stop_value = stop_criterion['value']
         else:
             self.criterion = 'regret'
-            self.stop_value = 1.0
+            self.stop_value = 0.1
 
     def run(self, trials=100, strategy=None, parameters=None):
         '''
         Run MAB test with T trials.
 
-        Paramters:
+        Parameters:
             trials (integer) - number of trials to run.
             strategy (string) - name of selected strategy.
             parameters (dict) - parameters for selected strategy.
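The first two changes in this hunk store the bound method self.regret_met in the criteria dict (instead of calling it once at construction) and lower the default stop_value to 0.1. Storing the method means the stopping check can be re-evaluated on every trial rather than frozen at init time. A minimal, self-contained illustration of that difference; the Demo class and its names are hypothetical, not part of this codebase:

    class Demo:
        def __init__(self):
            self.value = 0
            # store the method itself, mirroring {'regret': self.regret_met}
            self.checks = {'done': self.is_done}

        def is_done(self):
            return self.value >= 3

    d = Demo()
    print(d.checks['done']())   # False -- evaluated now, against current state
    d.value = 5
    print(d.checks['done']())   # True -- the same dict entry re-evaluates at call time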
@@ -112,15 +112,29 @@ def run(self, trials=100, strategy=None, parameters=None):
 
         # Run strategy
         for n in range(trials):
-            choice = strategies[strategy](params=parameters)
-            self.choices.append(choice)
-            payout = self.bandits.pull(choice)
-            if payout is None:
-                print('Trials exhausted. No more values for bandit', choice)
-                break
-            else:
-                self.wins[choice] += payout
-                self.pulls[choice] += 1
+            self._run(strategies[strategy], parameters)
+
+    def _run(self, strategy, parameters=None):
+        '''
+        Run single trial of MAB strategy.
+
+        Input:
+            strategy - function
+            parameters - dictionary
+
+        Output:
+            None
+        '''
+
+        choice = strategy(params=parameters)
+        self.choices.append(choice)
+        payout = self.bandits.pull(choice)
+        if payout is None:
+            print('Trials exhausted. No more values for bandit', choice)
+            return None
+        else:
+            self.wins[choice] += payout
+            self.pulls[choice] += 1
 
 
 ####### ----------- MAB strategies ---------------------------------------####
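A hedged usage sketch of the refactored run()/_run() flow: run() now delegates each of the trials iterations to _run(), which picks a bandit with the chosen strategy, pulls it, and updates wins/pulls for that single trial. The class name MAB, the strategy key 'bayesian', and the constructor arguments below are assumptions for illustration only; the diff does not show the available strategy names or the full __init__ signature:

    import numpy as np

    # Hypothetical setup: three simulated bandits with assumed success probabilities.
    mab = MAB(num_bandits=3, probs=[0.1, 0.3, 0.8], payouts=[1, 1, 1])

    # Each trial is now a single _run() call inside run()'s loop.
    mab.run(trials=200, strategy='bayesian', parameters=None)

    # Observed payout rate per bandit (guarding against bandits that were never pulled).
    print(mab.wins / np.maximum(mab.pulls, 1))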
@@ -259,7 +273,7 @@ def regret(self):
         Output: float
         '''
 
-        return (sum(self.pulls)*np.max(self.wins/self.pulls) -
+        return (sum(self.pulls)*np.max(np.nan_to_num(self.wins/self.pulls)) -
                 sum(self.wins)) / sum(self.pulls)
 
     def crit_met(self):
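The change above guards the regret calculation: wins / pulls produces 0/0 = nan for any bandit that has not been pulled yet, and np.max over an array containing nan returns nan, which would poison the whole regret value. np.nan_to_num maps those entries to 0 so the best observed payout rate survives. A small standalone check of the patched expression, with made-up numbers:

    import numpy as np

    wins = np.array([4.0, 0.0, 0.0])
    pulls = np.array([10.0, 5.0, 0.0])   # third bandit never pulled

    with np.errstate(invalid='ignore'):  # silence the 0/0 warning
        rates = wins / pulls             # [0.4, 0.0, nan]

    print(np.max(rates))                 # nan -- what the unpatched max would give
    print(np.max(np.nan_to_num(rates)))  # 0.4 -- nan mapped to 0

    regret = (pulls.sum() * np.max(np.nan_to_num(rates)) - wins.sum()) / pulls.sum()
    print(regret)                        # ~0.133, matching the patched formula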
@@ -280,7 +294,7 @@ def regret_met(self, threshold=None):
         '''
 
         if not threshold:
-            return False
+            return self.regret() <= self.stop_value
         elif self.regret() <= threshold:
             return True
         else:
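With this last change, calling regret_met() without an explicit threshold compares the current regret against the instance's stop_value (defaulting to 0.1 after this commit) instead of always returning False. A standalone restatement of that logic, purely for illustration:

    def regret_met(regret, stop_value=0.1, threshold=None):
        # no explicit threshold: fall back to the configured stop_value
        if not threshold:
            return regret <= stop_value
        elif regret <= threshold:
            return True
        else:
            return False

    print(regret_met(0.05))                  # True -- 0.05 <= default 0.1
    print(regret_met(0.05, threshold=0.01))  # False -- explicit threshold applies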