Skip to content

Commit efcf3b9

Browse files
authored
Merge pull request #11 from vnthanhdng/feature/learningAgents
Feat: Add learning agents (value iteration, qlearning)
2 parents 2da2d01 + 29e21a4 commit efcf3b9

File tree

8 files changed

+648
-5
lines changed

8 files changed

+648
-5
lines changed

.github/workflows/python-tests.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,7 @@ jobs:
3737
run: cd tests && python test_evaluation.py
3838

3939
- name: Run puzzle tests
40-
run: cd tests && python test_puzzles.py
40+
run: cd tests && python test_puzzles.py
41+
42+
- name: Run learning tests
43+
run: cd tests && python test_learning.py

main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import argparse
77
from engine import find_best_move_alpha_beta
88
from evaluation import evaluate
9-
from src.agents import MinimaxAgent, AlphaBetaAgent, ExpectimaxAgent
9+
from src.agents import MinimaxAgent, AlphaBetaAgent, ExpectimaxAgent, ValueIterationAgent, QLearningAgent
1010

1111
def print_board(board: chess.Board):
1212
""" Print the chess board in a readable format. """
@@ -108,6 +108,10 @@ def play_game(agent_name="alphabeta", ai_depth=3):
108108
ai_agent = AlphaBetaAgent(evaluate, depth=ai_depth, name="AlphaBetaAI", color=chess.BLACK)
109109
elif agent_name.lower() == "expectimax":
110110
ai_agent = ExpectimaxAgent(evaluate, depth=ai_depth, name="ExpectimaxAI", color=chess.BLACK)
111+
elif agent_name.lower() == "valueiteration":
112+
ai_agent = ValueIterationAgent()
113+
elif agent_name.lower() == "qlearning":
114+
ai_agent = QLearningAgent()
111115
else:
112116
print(f"Unknown agent: {agent_name}")
113117
print("Available agents: minimax, alphabeta, expectimax")
@@ -163,7 +167,7 @@ def play_game(agent_name="alphabeta", ai_depth=3):
163167

164168
# Show AI stats
165169
stats = ai_agent.get_search_info()
166-
print(f"AI searched {stats['nodes_searched']} nodes in {elapsed:.2f}s")
170+
# print(f"AI searched {stats['nodes_searched']} nodes in {elapsed:.2f}s")
167171
move_number += 1
168172

169173
print_board(board)
@@ -193,7 +197,7 @@ def main():
193197
"""Main entry point with command line argument parsing."""
194198
parser = argparse.ArgumentParser(description="Play chess against different AI agents")
195199
parser.add_argument("--agent", "-a",
196-
choices=["minimax", "alphabeta", "expectimax"],
200+
choices=["minimax", "alphabeta", "expectimax", "valueiteration", "qlearning"],
197201
default="alphabeta",
198202
help="AI agent to play against (default: alphabeta)")
199203
parser.add_argument("--depth", "-d",

src/agents/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@
22
from .base_agent import BaseAgent
from .human_agent import HumanAgent
from .search_agent import SearchAgent, MinimaxAgent, AlphaBetaAgent, ExpectimaxAgent
from .learning_agent import ReinforcementAgent
from .valueIteration_agent import ValueIterationAgent
from .qlearning_agent import QLearningAgent

# Public API of the agents package.
__all__ = [
    'BaseAgent',
    'HumanAgent',
    'SearchAgent',
    'MinimaxAgent',
    'AlphaBetaAgent',
    'ExpectimaxAgent',
    'ValueIterationAgent',
    'ReinforcementAgent',
    'QLearningAgent',
]

src/agents/learning_agent.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Base classes for Learning agents (Value iteration, Q-learning)."""
2+
3+
from .base_agent import BaseAgent
4+
import chess
5+
from abc import abstractmethod
6+
7+
class ValueEstimationAgent(BaseAgent):
    """Abstract agent that estimates Q-values Q(s, a) for state/action pairs.

    Derived agents realize:
        V(s)      = max_{a in actions} Q(s, a)
        policy(s) = arg_max_{a in actions} Q(s, a)
    """

    def __init__(self, alpha=1.0, epsilon=0.05, gamma=0.8, numTraining=10,
                 name="ValueEstimationAgent", color: chess.Color = chess.BLACK):
        """
        Args:
            alpha: learning rate.
            epsilon: exploration rate.
            gamma: discount factor.
            numTraining: number of training episodes, i.e. no learning after
                these many episodes.
            name: display name forwarded to BaseAgent.
            color: side this agent plays (chess.WHITE or chess.BLACK).
        """
        super().__init__(name, color)
        self.alpha = float(alpha)
        self.epsilon = float(epsilon)
        self.discount = float(gamma)
        self.numTraining = int(numTraining)

    def get_search_info(self):
        """Return search statistics for display.

        Learning agents do not search a game tree, so report zero nodes.
        Returning a dict (instead of the previous bare ``None``) keeps
        callers that index the result — e.g. main.py's
        ``stats['nodes_searched']`` print — from crashing.
        """
        return {'nodes_searched': 0}
27+
28+
class ReinforcementAgent(ValueEstimationAgent):
    """Abstract reinforcement-learning agent; Q-learning agents inherit from it.

    Tracks per-episode rewards and turns exploration/learning off once
    ``numTraining`` episodes have completed.
    """

    def __init__(self, numTraining=100, epsilon=0.5, alpha=0.5, gamma=1,
                 name="ReinforcementAgent", color: chess.Color = chess.BLACK):
        """
        Args:
            numTraining: number of training episodes, i.e. no learning after
                these many episodes.
            epsilon: exploration rate.
            alpha: learning rate.
            gamma: discount factor.
            name: display name forwarded to BaseAgent.
            color: side this agent plays.
        """
        super().__init__(alpha, epsilon, gamma, numTraining, name, color)
        self.episodesSoFar = 0
        self.accumTrainRewards = 0.0
        self.accumTestRewards = 0.0

    @abstractmethod
    def update(self, board, action, nextState, reward):
        """Incorporate one observed (state, action, nextState, reward) transition.

        Called by observeTransition(); subclasses implement the learning rule.
        """
        raise NotImplementedError

    def getLegalActions(self, board):
        """Return the actions (legal moves) available in the given state."""
        return board.legal_moves

    def observeTransition(self, board, action, nextState, deltaReward):
        """Record a transition reported by the environment.

        Accumulates the reward into the current episode total and forwards
        the same arguments to self.update().
        """
        self.episodeRewards += deltaReward
        self.update(board, action, nextState, deltaReward)

    def doAction(self, state, action):
        """Remember the last (state, action) pair a subclass acted on."""
        self.lastState = state
        self.lastAction = action

    def startEpisode(self):
        """Reset per-episode bookkeeping at the start of a training episode."""
        self.lastState = None
        self.lastAction = None
        self.episodeRewards = 0.0

    def stopEpisode(self):
        """Finish the current episode and tally its reward.

        Once ``numTraining`` episodes have completed the agent stops
        exploring and learning.
        """
        if self.episodesSoFar < self.numTraining:
            self.accumTrainRewards += self.episodeRewards
        else:
            self.accumTestRewards += self.episodeRewards
        self.episodesSoFar += 1
        if self.episodesSoFar >= self.numTraining:
            # Take off the training wheels
            self.epsilon = 0.0  # no exploration
            self.alpha = 0.0    # no learning

    def isInTraining(self):
        """True while fewer than ``numTraining`` episodes have completed."""
        return self.episodesSoFar < self.numTraining

    def isInTesting(self):
        """True once training has finished."""
        return not self.isInTraining()

src/agents/qlearning_agent.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""QLearning Agent."""
2+
3+
from .learning_agent import ReinforcementAgent
4+
from collections import defaultdict
5+
import random
6+
import chess
7+
from utils import flipCoin
8+
from evaluation import evaluate
9+
10+
class QLearningAgent(ReinforcementAgent):
    """Q-Learning agent: learns Q(s, a) by self-play against a random opponent."""

    def __init__(self, **args):
        """Create the agent and immediately run its training episodes.

        Keyword args are forwarded to ReinforcementAgent (numTraining,
        epsilon, alpha, gamma, name, color).
        """
        ReinforcementAgent.__init__(self, **args)
        # Q-table keyed by (FEN string, move); unseen pairs default to 0.0.
        self.q_values = defaultdict(float)
        # NOTE(review): this plays numTraining full games at construction
        # time, so instantiating the agent can be slow.
        self.train()

    def choose_move(self, board: chess.Board) -> chess.Move:
        """
        Determine best move based on current board state and Q-values.

        Args:
            board: The chess board to get the move for.

        Returns:
            chess.Move: The greedy move under the learned Q-values, or
            None when the game is over / there are no legal moves.
        """
        return self.computeActionFromQValues(board)

    def getQValue(self, board: chess.Board, action: chess.Move) -> float:
        """
        Return the Q-value for a given state and action.

        Returns 0.0 for a never-seen (state, action) pair, the learned
        Q-node value otherwise.

        Args:
            board: The board (state) to get the value for.
            action: The move to take from that state.
        """
        return self.q_values[(board.fen(), action)]

    def computeValueFromQValues(self, board: chess.Board) -> float:
        """
        Return the best Q-value reachable from the given board: the minimum
        over legal moves when playing Black, the maximum when playing White.
        Terminal states are worth 0.0.

        Args:
            board: The state from which to return the best value.
        """
        if board.is_game_over() or not board.legal_moves:
            return 0.0

        # BUG FIX: query Q(s, a) on the *current* board. The original
        # pushed the move first and looked up Q on the successor position,
        # a key update() never writes, so it always read back 0.0.
        values = [self.getQValue(board, move) for move in board.legal_moves]
        return min(values) if self.color == chess.BLACK else max(values)

    def computeActionFromQValues(self, board: chess.Board) -> chess.Move:
        """
        Compute the greedy action in a state (minimizing as Black,
        maximizing as White). Ties are broken uniformly at random.

        Args:
            board: The state from which to return the best action.

        Returns:
            chess.Move, or None when the game is over / no legal moves.
        """
        # NOTE(review): mutates board.turn so legal_moves are generated
        # for this agent's color — side effect visible to the caller.
        board.turn = self.color
        if board.is_game_over() or not board.legal_moves:
            return None

        bestVal = float('inf') if self.color == chess.BLACK else float('-inf')
        bestMoves = []
        for move in board.legal_moves:
            # BUG FIX: look up Q(s, a) on the current board, matching the
            # (fen, action) key that update() writes (see
            # computeValueFromQValues); the push/pop lookup never matched.
            curVal = self.getQValue(board, move)
            if curVal == bestVal:
                bestMoves.append(move)
            elif curVal < bestVal and self.color == chess.BLACK:
                bestVal = curVal
                bestMoves = [move]
            elif curVal > bestVal and self.color == chess.WHITE:
                bestVal = curVal
                bestMoves = [move]

        return random.choice(bestMoves)

    def getAction(self, board: chess.Board) -> chess.Move:
        """
        Epsilon-greedy selection: with probability epsilon pick a uniformly
        random legal move, otherwise the greedy move.

        Args:
            board: the state from which the action should be chosen.
        """
        board.turn = self.color
        if flipCoin(self.epsilon):
            return random.choice(list(board.legal_moves))
        return self.computeActionFromQValues(board)

    def final(self, board: chess.Board):
        """
        Called after an episode ends: credit the final reward and close
        out the episode.

        Args:
            board: The ending state of the board.
        """
        # Guard: if the game ended before this agent ever moved,
        # lastState is None and there is no transition to record —
        # the original unconditionally called evaluate(None) and crashed.
        if self.lastState is not None:
            deltaReward = evaluate(board) - evaluate(self.lastState)
            self.observeTransition(self.lastState, self.lastAction, board, deltaReward)
        self.stopEpisode()

    def registerInitialState(self, board: chess.Board):
        """Start training."""
        self.startEpisode()

    def update(self, board: chess.Board, action: chess.Move,
               nextBoard: chess.Board, reward: float):
        """
        Q-learning update for the transition state --action--> nextState:

            Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * V(s'))

        Args:
            board: The state of the board before the action.
            action: The chosen move.
            nextBoard: The board state after performing the action.
            reward: The reward for the action taken (evaluation delta,
                hence float rather than the previous int annotation).
        """
        sample = reward + self.discount * self.computeValueFromQValues(nextBoard)
        self.q_values[(board.fen(), action)] = (
            (1 - self.alpha) * self.getQValue(board, action) + self.alpha * sample
        )

    def train(self):
        """
        Train the agent for numTraining games against an opponent making
        uniformly random moves; rewards are evaluation deltas per move.
        """
        for _ in range(self.numTraining):
            board = chess.Board()
            self.startEpisode()
            while not board.is_game_over():
                if board.turn == self.color:
                    state = board.copy()
                    action = self.getAction(state)
                    self.doAction(state, action)
                    board.push(action)
                    nextState = board
                    reward = evaluate(nextState) - evaluate(state)
                    self.observeTransition(state, action, nextState, reward)
                else:
                    opp_moves = list(board.legal_moves)
                    if opp_moves:
                        board.push(random.choice(opp_moves))
            self.final(board)

0 commit comments

Comments
 (0)