Skip to content

Commit c6881a2

Browse files
authored
Merge pull request #10 from sinanh/master
eps_greedy as default parameter and PEP8 fixes
2 parents e3c1a12 + 208c0d9 commit c6881a2

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

slots/slots.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ def __init__(self, num_bandits=3, probs=None, payouts=None, live=True,
5050
probs=None)
5151
else:
5252
self.bandits = Bandits(probs=[np.random.rand() for x in
53-
range(num_bandits)],
53+
range(num_bandits)],
5454
payouts=np.ones(num_bandits),
5555
live=False)
5656
else:
5757

5858
self.bandits = Bandits(probs=[np.random.rand() for x in
59-
range(len(payouts))],
59+
range(len(payouts))],
6060
payouts=payouts,
6161
live=False)
6262
num_bandits = len(payouts)
@@ -82,7 +82,7 @@ def __init__(self, num_bandits=3, probs=None, payouts=None, live=True,
8282
# Bandit selection strategies
8383
self.strategies = ['eps_greedy', 'softmax', 'ucb', 'bayesian']
8484

85-
def run(self, trials=100, strategy=None, parameters=None):
85+
def run(self, trials=100, strategy='eps_greedy', parameters=None):
8686
'''
8787
Run MAB test with T trials.
8888
@@ -107,8 +107,7 @@ def run(self, trials=100, strategy=None, parameters=None):
107107

108108
if trials < 1:
109109
raise Exception('MAB.run: Number of trials cannot be less than 1!')
110-
if not strategy:
111-
strategy = 'eps_greedy'
110+
112111
else:
113112
if strategy not in self.strategies:
114113
raise Exception('MAB,run: Strategy name invalid. Choose from:'
@@ -193,7 +192,7 @@ def bayesian(self, params=None):
193192
p_success_arms = [
194193
np.random.beta(self.wins[i] + 1, self.pulls[i] - self.wins[i] + 1)
195194
for i in range(len(self.wins))
196-
]
195+
]
197196

198197
return np.array(p_success_arms).argmax()
199198

@@ -221,7 +220,7 @@ def eps_greedy(self, params):
221220

222221
if r < eps:
223222
return np.random.choice(list(set(range(len(self.wins))) -
224-
{self.max_mean()}))
223+
{self.max_mean()}))
225224
else:
226225
return self.max_mean()
227226

0 commit comments

Comments
 (0)