Skip to content

Commit ec5a8ec

Browse files
committed
Changed default settings to "live first" (online is default use case)
Fixed instantiation of n online bandits
1 parent 7e2d36b commit ec5a8ec

File tree

1 file changed

+29
-21
lines changed

1 file changed

+29
-21
lines changed

slots/slots.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class MAB(object):
2323
Multi-armed bandit test class.
2424
'''
2525

26-
def __init__(self, num_bandits=3, probs=None, payouts=None, live=False,
27-
stop_criterion={'criterion': 'regret', 'value': 1.0}):
26+
def __init__(self, num_bandits=3, probs=None, payouts=None, live=True,
27+
stop_criterion={'criterion': 'regret', 'value': 0.1}):
2828
'''
2929
Parameters
3030
----------
@@ -33,37 +33,42 @@ def __init__(self, num_bandits=3, probs=None, payouts=None, live=False,
3333
probs : np.array of floats
3434
payout probabilities
3535
payouts : np.array of floats
36-
If `live` is True, `payouts` should be an N*T array of payout
37-
amount per pull (floats) for N bandits and T trials
36+
If `live` is True, `payouts` should be None.
3837
live : bool
38+
Whether the use is for a live, online trial.
3939
stop_criterion : dict
4040
Stopping criterion (str) and threshold value (float).
4141
'''
4242

4343
self.choices = []
4444

4545
if not probs:
46-
if payouts is None:
47-
self.bandits = Bandits(probs=[np.random.rand() for x in
48-
range(num_bandits)],
49-
payouts=np.ones(num_bandits))
50-
else:
46+
if not payouts:
5147
if live:
52-
self.bandits = Bandits(live=True, payouts=payouts,
48+
self.bandits = Bandits(live=True,
49+
payouts=np.zeros(num_bandits),
5350
probs=None)
5451
else:
55-
# Not sure why anyone would do this
5652
self.bandits = Bandits(probs=[np.random.rand() for x in
57-
range(len(payouts))],
58-
payouts=payouts)
53+
range(num_bandits)],
54+
payouts=np.ones(num_bandits),
55+
live=False)
56+
else:
57+
58+
self.bandits = Bandits(probs=[np.random.rand() for x in
59+
range(len(payouts))],
60+
payouts=payouts,
61+
live=False)
5962
num_bandits = len(payouts)
6063
else:
6164
if payouts:
62-
self.bandits = Bandits(probs=probs, payouts=payouts)
65+
self.bandits = Bandits(probs=probs, payouts=payouts,
66+
live=False)
6367
num_bandits = len(payouts)
6468
else:
6569
self.bandits = Bandits(probs=probs,
66-
payouts=np.ones(len(probs)))
70+
payouts=np.ones(len(probs)),
71+
live=False)
6772
num_bandits = len(probs)
6873

6974
self.wins = np.zeros(num_bandits)
@@ -361,7 +366,10 @@ def crit_met(self):
361366
bool
362367
'''
363368

364-
return self.criteria[self.criterion](self.stop_value)
369+
if True in (self.pulls < 3):
370+
return False
371+
else:
372+
return self.criteria[self.criterion](self.stop_value)
365373

366374
def regret_met(self, threshold=None):
367375
'''
@@ -409,7 +417,7 @@ def online_trial(self, bandit=None, payout=None, strategy='eps_greedy',
409417
Format: {'new_trial': boolean, 'choice': int, 'best': int}
410418
'''
411419

412-
if bandit and payout:
420+
if bandit is not None and payout is not None:
413421
self.update(bandit=bandit, payout=payout)
414422
else:
415423
raise Exception('slots.online_trial: bandit and/or payout value'
@@ -423,7 +431,7 @@ def online_trial(self, bandit=None, payout=None, strategy='eps_greedy',
423431
'choice': self.run_strategy(strategy, parameters),
424432
'best': self.best()}
425433

426-
def update(self, bandit=None, payout=None):
434+
def update(self, bandit, payout):
427435
'''
428436
Update bandit trials and payouts for given bandit.
429437
@@ -438,6 +446,7 @@ def update(self, bandit=None, payout=None):
438446
None
439447
'''
440448

449+
self.choices.append(bandit)
441450
self.pulls[bandit] += 1
442451
self.wins[bandit] += payout
443452
self.bandits.payouts[bandit] += payout
@@ -448,7 +457,7 @@ class Bandits():
448457
Bandit class.
449458
'''
450459

451-
def __init__(self, probs, payouts, live=False):
460+
def __init__(self, probs, payouts, live=True):
452461
'''
453462
Instantiate Bandit class, determining
454463
- Probabilities of bandit payouts
@@ -460,8 +469,7 @@ def __init__(self, probs, payouts, live=False):
460469
Probabilities of bandit payouts
461470
payouts : array of floats
462471
Amount of bandit payouts. If `live` is True, `payouts` should be an
463-
N*T array of floats giving the payout amount per pull for N bandits
464-
and T trials.
472+
N length array of zeros.
465473
live : bool
466474
'''
467475

0 commit comments

Comments
 (0)