Skip to content

Commit 0a74299

Browse files
committed
Add MCCFR
1 parent 555338e commit 0a74299

File tree

4 files changed

+349
-6
lines changed

4 files changed

+349
-6
lines changed

noregret/games.py

Lines changed: 178 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
2+
from collections import defaultdict
23
from dataclasses import dataclass
3-
from functools import partial
4+
from functools import cache, partial
45
from itertools import permutations
56
from math import factorial
67
from typing import Any
@@ -73,8 +74,11 @@ def nash_gap(self, *strategies):
7374

7475
for i, value in enumerate(self.values(*strategies)):
7576
opponent_strategies = strategies[:i] + strategies[i + 1:]
76-
_, br_value = self.best_response(i, *opponent_strategies)
77-
gap += br_value - value
77+
_, best_response_value = self.best_response(
78+
i,
79+
*opponent_strategies,
80+
)
81+
gap += best_response_value - value
7882

7983
return gap
8084

@@ -86,8 +90,11 @@ def cce_gap(self, *strategies):
8690
average_opponent_strategies = (
8791
average_strategies[:i] + average_strategies[i + 1:]
8892
)
89-
_, br_value = self.best_response(i, *average_opponent_strategies)
90-
gap += br_value - value
93+
_, best_response_value = self.best_response(
94+
i,
95+
*average_opponent_strategies,
96+
)
97+
gap += best_response_value - value
9198

9299
return gap
93100

@@ -613,3 +620,169 @@ def correlated_value(self, player, *strategies):
613620

614621
def best_response(self, player, *opponent_strategies):
615622
raise NotImplementedError
623+
624+
625+
class ExtensiveFormGame2(ABC):
    """Extensive-form game (EFG).

    A game tree with chance nodes, imperfect information (infosets), and
    per-player utilities at the leaves.  Concrete games supply the player
    list, the root state, and a ``State`` implementation; the generic
    expected-value, best-response, and Nash-gap computations live here.

    A strategy profile is a callable: ``strategy_profile(state)`` returns
    action probabilities for the player to act at ``state``, aligned with
    ``state.actions``.
    """

    @dataclass(frozen=True)
    class State:
        """State (history) of an extensive-form game.

        Frozen so instances are hashable -- they are used as dictionary
        and memoisation keys by the solvers below.
        """

        @property
        @abstractmethod
        def utilities(self):
            """Per-player payoffs at this terminal state.

            NOTE(review): the return value must support scalar
            multiplication and addition (e.g. a numpy array) for
            ``values`` to accumulate it -- confirm with implementors.
            """

        @property
        @abstractmethod
        def chance_action_probabilities(self):
            """Iterable of ``(action, probability)`` pairs at a chance node."""

        @property
        @abstractmethod
        def actions(self):
            """Legal actions at this non-chance, non-terminal state."""

        @property
        @abstractmethod
        def infoset(self):
            """Hashable information-set identifier for the acting player."""

        @property
        @abstractmethod
        def player(self):
            """The player to act at this state."""

        @abstractmethod
        def is_terminal(self):
            """Whether this state is a leaf of the game tree."""

        @abstractmethod
        def is_chance(self):
            """Whether nature (chance) acts at this state."""

        @abstractmethod
        def utility(self, player):
            """Payoff of ``player`` at this terminal state."""

        @abstractmethod
        def apply(self, action):
            """Successor state reached by taking ``action``."""

    @property
    @abstractmethod
    def players(self):
        """Iterable of the game's players."""

    @property
    @abstractmethod
    def initial_state(self):
        """Root state of the game tree."""

    def values(self, strategy_profile, state=None):
        """Expected per-player payoffs under ``strategy_profile``.

        Recursively averages subtree payoffs, weighting by chance
        probabilities at chance nodes and by the profile's action
        probabilities everywhere else.  ``state`` defaults to the root.
        """
        if state is None:
            return self.values(strategy_profile, self.initial_state)

        if state.is_terminal():
            return state.utilities

        if state.is_chance():
            actions, probabilities = zip(*state.chance_action_probabilities)
        else:
            actions = state.actions
            probabilities = strategy_profile(state)

        return sum(
            probability * self.values(strategy_profile, state.apply(action))
            for action, probability in zip(actions, probabilities)
        )

    def best_response_value(self, player, strategy_profile):
        """Value of ``player``'s best response to ``strategy_profile``.

        First pass (``explore``) walks the whole tree, grouping
        ``player``'s decision states by infoset and recording each
        state's counterfactual reach probability -- the product of
        chance and opponent probabilities only (the responder's own
        choices contribute no factor).  Second pass solves bottom-up:
        at an infoset, the best response picks the single action index
        maximising the reach-weighted sum of successor values across
        all states in the infoset.
        """
        infoset_states = defaultdict(list)
        reach = {}

        def explore(state, probability):
            # probability excludes player's own action probabilities
            # (counterfactual reach).
            reach[state] = probability

            if state.is_terminal():
                return

            if not state.is_chance():
                infoset_states[state.infoset].append(state)

            if not state.is_chance() and state.player == player:
                # Responder's node: branch on every action, reach unchanged.
                for action in state.actions:
                    explore(state.apply(action), probability)

                return

            if state.is_chance():
                actions, probabilities = zip(
                    *state.chance_action_probabilities,
                )
            else:
                actions = state.actions
                probabilities = strategy_profile(state)

            for action, action_probability in zip(actions, probabilities):
                explore(state.apply(action), action_probability * probability)

        explore(self.initial_state, 1)

        @cache
        def state_value(state):
            if state.is_terminal():
                return state.utility(player)

            if not state.is_chance() and state.player == player:
                # The responder plays the same best action across the
                # whole infoset, so defer to the infoset-level solve.
                return infoset_value(state.infoset)

            if state.is_chance():
                actions, probabilities = zip(
                    *state.chance_action_probabilities,
                )
            else:
                actions = state.actions
                probabilities = strategy_profile(state)

            return sum(
                probability * state_value(state.apply(action))
                for action, probability in zip(actions, probabilities)
            )

        @cache
        def infoset_value(infoset):
            # Reach-weighted value of each action index, aggregated over
            # every state the infoset contains.
            action_values = defaultdict(int)

            for state in infoset_states[infoset]:
                weight = reach[state]

                for index, action in enumerate(state.actions):
                    action_values[index] += (
                        weight * state_value(state.apply(action))
                    )

            return max(action_values.values())

        return state_value(self.initial_state)

    def nash_gap(self, strategy_profile):
        """Sum over players of (best-response value - profile value).

        Zero exactly when ``strategy_profile`` is a Nash equilibrium.
        """
        return sum(
            self.best_response_value(player, strategy_profile) - value
            for player, value in zip(
                self.players,
                self.values(strategy_profile),
            )
        )

0 commit comments

Comments
 (0)