Skip to content

Commit 7da8991

Browse files
committed
Update strategy selection code and online fixes.
Misc fixes.
1 parent f12145d commit 7da8991

File tree

1 file changed

+43
-22
lines changed

1 file changed

+43
-22
lines changed

slots.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(self, num_bandits=None, probs=None, payouts=None, live=False,
6969
num_bandits = len(payouts)
7070
else:
7171
self.bandits = Bandits(probs=probs,
72-
payouts=np.ones(len(payouts)))
72+
payouts=np.ones(len(probs)))
7373
num_bandits = len(probs)
7474

7575
self.wins = np.zeros(num_bandits)
@@ -85,6 +85,9 @@ def __init__(self, num_bandits=None, probs=None, payouts=None, live=False,
8585
self.criterion = 'regret'
8686
self.stop_value = 0.1
8787

88+
# Bandit selection strategies
89+
self.strategies = ['eps_greedy', 'softmax', 'ucb']
90+
8891
def run(self, trials=100, strategy=None, parameters=None):
8992
'''
9093
Run MAB test with T trials.
@@ -94,25 +97,24 @@ def run(self, trials=100, strategy=None, parameters=None):
9497
strategy (string) - name of selected strategy.
9598
parameters (dict) - parameters for selected strategy.
9699
97-
Currently on epsilon greedy is implemented.
100+
Available strategies:
101+
- Epsilon-greedy ("eps_greedy")
102+
- Softmax ("softmax")
103+
- Upper confidence bound ("ucb")
98104
'''
99105

100-
strategies = {'eps_greedy': self.eps_greedy,
101-
'softmax': self.softmax,
102-
'ucb': self.ucb}
103-
104106
if trials < 1:
105107
raise Exception('MAB.run: Number of trials cannot be less than 1!')
106108
if not strategy:
107109
strategy = 'eps_greedy'
108110
else:
109-
if strategy not in strategies:
111+
if strategy not in self.strategies:
110112
raise Exception('MAB,run: Strategy name invalid. Choose from:'
111-
' {}'.format(', '.join(strategies)))
113+
' {}'.format(', '.join(self.strategies)))
112114

113115
# Run strategy
114116
for n in range(trials):
115-
self._run(strategies[strategy], parameters)
117+
self._run(strategy, parameters)
116118

117119
def _run(self, strategy, parameters=None):
118120
'''
@@ -126,7 +128,7 @@ def _run(self, strategy, parameters=None):
126128
None
127129
'''
128130

129-
choice = strategy(params=parameters)
131+
choice = self.run_strategy(strategy, parameters)
130132
self.choices.append(choice)
131133
payout = self.bandits.pull(choice)
132134
if payout is None:
@@ -136,6 +138,19 @@ def _run(self, strategy, parameters=None):
136138
self.wins[choice] += payout
137139
self.pulls[choice] += 1
138140

141+
def run_strategy(self, strategy, parameters):
142+
'''
143+
Run the selected strategy and return bandit choice.
144+
145+
Input:
146+
strategy - string of strategy name
147+
parameters - dict of strategy function parameters
148+
149+
Output:
150+
integer. Call strategy function, which returns bandit arm choice.
151+
'''
152+
153+
return self.__getattribute__(strategy)(params=parameters)
139154

140155
# ###### ----------- MAB strategies ---------------------------------------####
141156
def max_mean(self):
@@ -145,6 +160,7 @@ def max_mean(self):
145160
Input: self
146161
Output: int (index of chosen bandit)
147162
"""
163+
148164
return np.argmax(self.wins / (self.pulls + 0.1))
149165

150166
def eps_greedy(self, params):
@@ -161,6 +177,7 @@ def eps_greedy(self, params):
161177
eps = 0.1
162178

163179
r = np.random.rand()
180+
164181
if r < eps:
165182
return np.random.choice(list(set(range(len(self.wins))) -
166183
{self.max_mean()}))
@@ -175,7 +192,7 @@ def softmax(self, params):
175192
Output: int (index of chosen bandit)
176193
'''
177194

178-
default_tau = 1.0
195+
default_tau = 0.1
179196

180197
if params and type(params) == dict:
181198
tau = params.get('tau')
@@ -189,7 +206,7 @@ def softmax(self, params):
189206

190207
# Handle cold start. Not all bandits tested yet.
191208
if True in (self.pulls < 3):
192-
return np.random.choice(xrange(len(self.pulls)))
209+
return np.random.choice(range(len(self.pulls)))
193210
else:
194211
payouts = self.wins / (self.pulls + 0.1)
195212
norm = sum(np.exp(payouts/tau))
@@ -225,7 +242,7 @@ def ucb(self, params=None):
225242

226243
# Handle cold start. Not all bandits tested yet.
227244
if True in (self.pulls < 3):
228-
return np.random.choice(xrange(len(self.pulls)))
245+
return np.random.choice(range(len(self.pulls)))
229246
else:
230247
n_tot = sum(self.pulls)
231248
payouts = self.wins / (self.pulls + 0.1)
@@ -301,16 +318,22 @@ def regret_met(self, threshold=None):
301318
return False
302319

303320
# ## ------------ Online bandit testing ------------------------------ ####
304-
def online_trial(self, bandit=None, payout=None):
321+
def online_trial(self, bandit=None, payout=None, strategy='eps_greedy',
322+
parameters=None):
305323
'''
306324
Update the bandits with the results of the previous live, online trial.
307325
Next run the selection algorithm. If the stopping criterion is
308326
met, return the best arm estimate. Otherwise return the next arm to
309327
try.
310328
311-
Input: int (bandit to update), float (payout to update for bandit)
312-
Output: dict
313-
format: {'new_trial': boolean, 'choice': int, 'best': int}
329+
Input:
330+
bandit - int of bandit index
331+
payout - float of payout value
332+
strategy - string name of update strategy
333+
parameters - dict of parameters for update strategy function
334+
335+
Output:
336+
dict - format: {'new_trial': boolean, 'choice': int, 'best': int}
314337
'''
315338

316339
if bandit and payout:
@@ -323,11 +346,9 @@ def online_trial(self, bandit=None, payout=None):
323346
return {'new_trial': False, 'choice': self.best(),
324347
'best': self.best()}
325348
else:
326-
# TODO: implement choice via strategy
327-
print('slots: online trial strategy not yet implemented.'
328-
' No trial run.')
329-
# choice = self.run(trial=1)
330-
pass
349+
return {'new_trial': True,
350+
'choice': self.run_strategy(strategy, parameters),
351+
'best': self.best()}
331352

332353
def update(self, bandit=None, payout=None):
333354
'''

0 commit comments

Comments
 (0)