Commit 0f69159

Merge branch PR #3
2 parents 471b075 + 7ad4b70 commit 0f69159

3 files changed: 15 additions, 1 deletion


README.md

Lines changed: 5 additions & 0 deletions
@@ -75,18 +75,22 @@ probs = [0.4, 0.9, 0.8]
 ba = slots.MAB(probs=probs)
 bb = slots.MAB(probs=probs)
 bc = slots.MAB(probs=probs)
+bd = slots.MAB(probs=probs)
 
 # Run trials and calculate the regret after each trial
 rega = []
 regb = []
 regc = []
+regd = []
 for t in range(10000):
     ba._run('eps_greedy')
     rega.append(ba.regret())
     bb._run('softmax')
     regb.append(bb.regret())
     bc._run('ucb')
     regc.append(bc.regret())
+    bd._run('bayesian_bandit')
+    regd.append(bd.regret())
 
 
 # Pretty plotting
@@ -97,6 +101,7 @@ plt.figure(figsize=(15,4))
 plt.plot(rega, label='$\epsilon$-greedy ($\epsilon$=0.1)')
 plt.plot(regb, label='Softmax ($T$=0.1)')
 plt.plot(regc, label='UCB')
+plt.plot(regd, label='Bayesian Bandit')
 plt.legend()
 plt.xlabel('Trials')
 plt.ylabel('Regret')
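For reference, the new strategy can also be exercised through the public run() method whose signature appears in the slots/slots.py diff below. This is a minimal sketch, not part of the commit, assuming only the MAB API used in the README example above (probs=, regret()); the variable names and trial count are illustrative:

import slots

# Hypothetical standalone check of the new strategy (illustrative only).
probs = [0.4, 0.9, 0.8]
mab = slots.MAB(probs=probs)
# Public entry point shown in the diff: run(trials=100, strategy=None, parameters=None)
mab.run(trials=10000, strategy='bayesian_bandit')
print(mab.regret())  # regret after the run, as in the README example above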

misc/regret_plot.png

9.43 KB

slots/slots.py

File mode changed: 100644 → 100755
Lines changed: 10 additions & 1 deletion
@@ -75,7 +75,7 @@ def __init__(self, num_bandits=3, probs=None, payouts=None, live=False,
         self.stop_value = stop_criterion.get('value', 0.1)
 
         # Bandit selection strategies
-        self.strategies = ['eps_greedy', 'softmax', 'ucb']
+        self.strategies = ['eps_greedy', 'softmax', 'ucb', 'bayesian_bandit']
 
     def run(self, trials=100, strategy=None, parameters=None):
         '''
@@ -169,6 +169,15 @@ def max_mean(self):
 
         return np.argmax(self.wins / (self.pulls + 0.1))
 
+    def bayesian_bandit(self, params):
+        '''
+        Run the Bayesian Bandit algorithm, which uses a Beta distribution to balance exploration and exploitation.
+        :param params: unused by this strategy; accepted for interface consistency
+        :return: index of the arm with the highest sampled success probability
+        '''
+        p_success_arms = [np.random.beta(self.wins[i] + 1, self.pulls[i] - self.wins[i] + 1) for i in range(len(self.wins))]
+        return np.array(p_success_arms).argmax()
+
     def eps_greedy(self, params):
         '''
         Run the epsilon-greedy strategy and update self.max_mean()
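The added method is a Thompson-sampling style selection rule: each arm's success probability is drawn from a Beta(wins + 1, pulls − wins + 1) posterior, and the arm with the largest draw is played. A minimal standalone sketch of that rule, using hypothetical counts not taken from the repository:

import numpy as np

# Hypothetical win/pull totals for three arms (illustrative only).
wins = np.array([12, 30, 25])
pulls = np.array([40, 45, 38])

# Draw one sample per arm from its Beta posterior, Beta(wins + 1, losses + 1),
# then play the arm with the largest sampled success probability.
samples = np.random.beta(wins + 1, pulls - wins + 1)
chosen_arm = int(np.argmax(samples))
print(chosen_arm)

Because under-explored arms have wide posteriors, they are occasionally sampled high and tried again, while well-observed good arms are exploited most of the time.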
