Skip to content

Commit f12145d

Browse files
committed
Clean up and refactoring.
Fixed regret NaN issue when arm has no pulls.
1 parent 0d1d42f commit f12145d

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

slots.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,20 @@ def __init__(self, num_bandits=None, probs=None, payouts=None, live=False,
7676
self.pulls = np.zeros(num_bandits)
7777

7878
# Set the stopping criteria
79-
self.criteria = {'regret': self.regret_met()}
79+
self.criteria = {'regret': self.regret_met}
8080
if stop_criterion.get('criterion') in self.criteria:
8181
self.criterion = stop_criterion['criterion']
8282
if stop_criterion.get('value'):
8383
self.stop_value = stop_criterion['value']
8484
else:
8585
self.criterion = 'regret'
86-
self.stop_value = 1.0
86+
self.stop_value = 0.1
8787

8888
def run(self, trials=100, strategy=None, parameters=None):
8989
'''
9090
Run MAB test with T trials.
9191
92-
Paramters:
92+
Parameters:
9393
trials (integer) - number of trials to run.
9494
strategy (string) - name of selected strategy.
9595
parameters (dict) - parameters for selected strategy.
@@ -112,15 +112,29 @@ def run(self, trials=100, strategy=None, parameters=None):
112112

113113
# Run strategy
114114
for n in range(trials):
115-
choice = strategies[strategy](params=parameters)
116-
self.choices.append(choice)
117-
payout = self.bandits.pull(choice)
118-
if payout is None:
119-
print('Trials exhausted. No more values for bandit', choice)
120-
break
121-
else:
122-
self.wins[choice] += payout
123-
self.pulls[choice] += 1
115+
self._run(strategies[strategy], parameters)
116+
117+
def _run(self, strategy, parameters=None):
118+
'''
119+
Run single trial of MAB strategy.
120+
121+
Input:
122+
strategy - function
123+
parameters - dictionary
124+
125+
Output:
126+
None
127+
'''
128+
129+
choice = strategy(params=parameters)
130+
self.choices.append(choice)
131+
payout = self.bandits.pull(choice)
132+
if payout is None:
133+
print('Trials exhausted. No more values for bandit', choice)
134+
return None
135+
else:
136+
self.wins[choice] += payout
137+
self.pulls[choice] += 1
124138

125139

126140
# ###### ----------- MAB strategies ---------------------------------------####
@@ -259,7 +273,7 @@ def regret(self):
259273
Output: float
260274
'''
261275

262-
return (sum(self.pulls)*np.max(self.wins/self.pulls) -
276+
return (sum(self.pulls)*np.max(np.nan_to_num(self.wins/self.pulls)) -
263277
sum(self.wins)) / sum(self.pulls)
264278

265279
def crit_met(self):
@@ -280,7 +294,7 @@ def regret_met(self, threshold=None):
280294
'''
281295

282296
if not threshold:
283-
return False
297+
return self.regret() <= self.stop_value
284298
elif self.regret() <= threshold:
285299
return True
286300
else:

0 commit comments

Comments
 (0)