11from abc import ABC , abstractmethod
2+ from collections import defaultdict
23from dataclasses import dataclass
3- from functools import partial
4+ from functools import cache , partial
45from itertools import permutations
56from math import factorial
67from typing import Any
@@ -73,8 +74,11 @@ def nash_gap(self, *strategies):
7374
7475 for i , value in enumerate (self .values (* strategies )):
7576 opponent_strategies = strategies [:i ] + strategies [i + 1 :]
76- _ , br_value = self .best_response (i , * opponent_strategies )
77- gap += br_value - value
77+ _ , best_response_value = self .best_response (
78+ i ,
79+ * opponent_strategies ,
80+ )
81+ gap += best_response_value - value
7882
7983 return gap
8084
@@ -86,8 +90,11 @@ def cce_gap(self, *strategies):
8690 average_opponent_strategies = (
8791 average_strategies [:i ] + average_strategies [i + 1 :]
8892 )
89- _ , br_value = self .best_response (i , * average_opponent_strategies )
90- gap += br_value - value
93+ _ , best_response_value = self .best_response (
94+ i ,
95+ * average_opponent_strategies ,
96+ )
97+ gap += best_response_value - value
9198
9299 return gap
93100
@@ -613,3 +620,169 @@ def correlated_value(self, player, *strategies):
613620
    def best_response(self, player, *opponent_strategies):
        """Return ``player``'s best response to the opponents' strategies.

        Expected to return a ``(best_response_strategy, value)`` pair — the
        callers in ``nash_gap``/``cce_gap`` unpack two values and use only the
        second. Concrete game classes must override this; the base
        implementation is intentionally unimplemented.
        """
        raise NotImplementedError
623+
624+
class ExtensiveFormGame2(ABC):
    """Extensive-form game (EFG).

    A game is described by an initial state and per-state transition logic;
    ``values``, ``best_response_value``, and ``nash_gap`` are generic
    tree-walk algorithms over that interface.
    """

    @dataclass(frozen=True)
    class State:
        """State (node) of an extensive-form game.

        Frozen so instances are hashable — the solvers below cache on and
        key dictionaries by states.
        """

        @property
        @abstractmethod
        def utilities(self):
            """Per-player utilities at this state (terminal states only).

            NOTE(review): ``values`` multiplies this by a float and sums —
            concrete games presumably return a numpy array or similar;
            a plain tuple would not support that arithmetic. Confirm.
            """

        @property
        @abstractmethod
        def chance_action_probabilities(self):
            """Iterable of ``(action, probability)`` pairs (chance states only)."""

        @property
        @abstractmethod
        def actions(self):
            """Actions available at this state (decision states only)."""

        @property
        @abstractmethod
        def infoset(self):
            """Hashable identifier of the information set containing this state."""

        @property
        @abstractmethod
        def player(self):
            """Player to act at this state (decision states only)."""

        @abstractmethod
        def is_terminal(self):
            """Return True if this state is terminal."""

        @abstractmethod
        def is_chance(self):
            """Return True if chance (nature) acts at this state."""

        @abstractmethod
        def utility(self, player):
            """Return ``player``'s utility at this (terminal) state."""

        @abstractmethod
        def apply(self, action):
            """Return the successor state reached by taking ``action``."""

    @property
    @abstractmethod
    def players(self):
        """The players of the game."""

    @property
    @abstractmethod
    def initial_state(self):
        """The root state of the game tree."""

    def values(self, strategy_profile, state=None):
        """Return the expected utilities under ``strategy_profile``.

        ``strategy_profile`` is a callable mapping a decision state to the
        tuple of probabilities over ``state.actions``. The result has the
        same type as ``State.utilities`` (which must support scalar
        multiplication and addition).
        """
        if state is None:
            state = self.initial_state

        if state.is_terminal():
            return state.utilities

        if state.is_chance():
            actions, probabilities = zip(*state.chance_action_probabilities)
        else:
            actions = state.actions
            probabilities = strategy_profile(state)

        values = 0

        for action, probability in zip(actions, probabilities):
            values += probability * self.values(strategy_profile, state.apply(action))

        return values

    def best_response_value(self, player, strategy_profile):
        """Return the value of ``player``'s best response to ``strategy_profile``.

        Computed bottom-up over information sets: at each of ``player``'s
        infosets the best action maximizes the counterfactual-reach-weighted
        sum of child values — the standard best-response rule, valid under
        perfect recall. The per-state value is then the value of playing
        that infoset-wide best action at that particular state.

        BUGFIX: the previous version returned the reach-weighted infoset
        value itself as each member state's value, so reach probabilities
        were applied twice (once inside the infoset aggregation and again
        on the path back to the root), understating the best-response value
        whenever an infoset is reached with probability < 1.
        """
        infoset_states = defaultdict(list)
        counterfactual_reach_probabilities = {}

        def dfs(state, counterfactual_reach_probability):
            # Counterfactual reach: probability of reaching ``state`` from
            # chance and the opponents only — ``player``'s own action
            # probabilities are excluded (passed through unchanged below).
            counterfactual_reach_probabilities[state] = (
                counterfactual_reach_probability
            )

            if state.is_terminal():
                return

            if not state.is_chance():
                infoset_states[state.infoset].append(state)

            if state.is_chance() or state.player != player:
                if state.is_chance():
                    actions, probabilities = zip(
                        *state.chance_action_probabilities,
                    )
                else:
                    actions = state.actions
                    probabilities = strategy_profile(state)

                for action, probability in zip(actions, probabilities):
                    dfs(
                        state.apply(action),
                        probability * counterfactual_reach_probability,
                    )
            else:
                # The responding player's own node: explore every action
                # without discounting the reach probability.
                for action in state.actions:
                    dfs(state.apply(action), counterfactual_reach_probability)

        dfs(self.initial_state, 1)

        @cache
        def state_value(state):
            """Expected utility for ``player`` at ``state`` when ``player``
            best-responds and everyone else follows ``strategy_profile``."""
            if state.is_terminal():
                return state.utility(player)

            if state.is_chance() or state.player != player:
                if state.is_chance():
                    actions, probabilities = zip(
                        *state.chance_action_probabilities,
                    )
                else:
                    actions = state.actions
                    probabilities = strategy_profile(state)

                value = 0

                for action, probability in zip(actions, probabilities):
                    value += probability * state_value(state.apply(action))

                return value

            # The responding player acts: play the single best action chosen
            # for the whole infoset (caching makes the choice consistent
            # across all states of the infoset, as a valid strategy requires).
            best = best_action_index(state.infoset)

            return state_value(state.apply(state.actions[best]))

        @cache
        def best_action_index(infoset):
            """Index of the reach-weighted best action at ``infoset``.

            Assumes (perfect recall) that every state in the infoset offers
            the same actions in the same order.
            """
            action_values = defaultdict(int)

            for state in infoset_states[infoset]:
                weight = counterfactual_reach_probabilities[state]

                for i, action in enumerate(state.actions):
                    action_values[i] += weight * state_value(state.apply(action))

            return max(action_values, key=action_values.get)

        return state_value(self.initial_state)

    def nash_gap(self, strategy_profile):
        """Return the sum over players of (best-response value − actual value).

        Zero iff ``strategy_profile`` is a Nash equilibrium; positive
        otherwise.
        """
        gap = 0

        for player, value in zip(self.players, self.values(strategy_profile)):
            best_response_value = self.best_response_value(
                player,
                strategy_profile,
            )
            gap += best_response_value - value

        return gap
0 commit comments