Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,7 @@ jobs:
run: cd tests && python test_evaluation.py

- name: Run puzzle tests
run: cd tests && python test_puzzles.py
run: cd tests && python test_puzzles.py

- name: Run learning tests
run: cd tests && python test_learning.py
10 changes: 7 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import argparse
from engine import find_best_move_alpha_beta
from evaluation import evaluate
from src.agents import MinimaxAgent, AlphaBetaAgent, ExpectimaxAgent
from src.agents import MinimaxAgent, AlphaBetaAgent, ExpectimaxAgent, ValueIterationAgent, QLearningAgent

def print_board(board: chess.Board):
""" Print the chess board in a readable format. """
Expand Down Expand Up @@ -108,6 +108,10 @@ def play_game(agent_name="alphabeta", ai_depth=3):
ai_agent = AlphaBetaAgent(evaluate, depth=ai_depth, name="AlphaBetaAI", color=chess.BLACK)
elif agent_name.lower() == "expectimax":
ai_agent = ExpectimaxAgent(evaluate, depth=ai_depth, name="ExpectimaxAI", color=chess.BLACK)
elif agent_name.lower() == "valueiteration":
ai_agent = ValueIterationAgent()
elif agent_name.lower() == "qlearning":
ai_agent = QLearningAgent()
else:
print(f"Unknown agent: {agent_name}")
print("Available agents: minimax, alphabeta, expectimax")
Expand Down Expand Up @@ -163,7 +167,7 @@ def play_game(agent_name="alphabeta", ai_depth=3):

# Show AI stats
stats = ai_agent.get_search_info()
print(f"AI searched {stats['nodes_searched']} nodes in {elapsed:.2f}s")
# print(f"AI searched {stats['nodes_searched']} nodes in {elapsed:.2f}s")
move_number += 1

print_board(board)
Expand Down Expand Up @@ -193,7 +197,7 @@ def main():
"""Main entry point with command line argument parsing."""
parser = argparse.ArgumentParser(description="Play chess against different AI agents")
parser.add_argument("--agent", "-a",
choices=["minimax", "alphabeta", "expectimax"],
choices=["minimax", "alphabeta", "expectimax", "valueiteration", "qlearning"],
default="alphabeta",
help="AI agent to play against (default: alphabeta)")
parser.add_argument("--depth", "-d",
Expand Down
8 changes: 7 additions & 1 deletion src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
from .base_agent import BaseAgent
from .human_agent import HumanAgent
from .search_agent import SearchAgent, MinimaxAgent, AlphaBetaAgent, ExpectimaxAgent
from .learning_agent import ReinforcementAgent
from .valueIteration_agent import ValueIterationAgent
from .qlearning_agent import QLearningAgent

__all__ = [
'BaseAgent',
'HumanAgent',
'SearchAgent',
'MinimaxAgent',
'AlphaBetaAgent',
'ExpectimaxAgent'
'ExpectimaxAgent',
'ValueIterationAgent',
'ReinforcementAgent',
'QLearningAgent'
]
104 changes: 104 additions & 0 deletions src/agents/learning_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Base classes for Learning agents (Value iteration, Q-learning)."""

from .base_agent import BaseAgent
import chess
from abc import abstractmethod

class ValueEstimationAgent(BaseAgent):
    """Abstract agent built on estimated state/action values.

    Contract for subclasses:
        V(s)      = max_{a in actions} Q(s, a)
        policy(s) = arg_max_{a in actions} Q(s, a)
    """

    def __init__(self, alpha=1.0, epsilon=0.05, gamma=0.8, numTraining = 10, name="ValueEstimationAgent", color:chess.Color = chess.BLACK):
        """
        Args:
            alpha: learning rate.
            epsilon: exploration rate.
            gamma: discount factor (stored as ``self.discount``).
            numTraining: number of training episodes, i.e. no learning after
                these many episodes.
            name: display name for this agent.
            color: side this agent plays (chess.WHITE or chess.BLACK).
        """
        super().__init__(name, color)
        # Coerce to canonical numeric types so callers may pass ints/strings.
        self.alpha = float(alpha)
        self.epsilon = float(epsilon)
        self.discount = float(gamma)
        self.numTraining = int(numTraining)

    def get_search_info(self):
        """Learning agents keep no search statistics; returns None."""
        return None

class ReinforcementAgent(ValueEstimationAgent):
    """
    Abstract Reinforcement Agent: Q-Learning agents should inherit.

    Tracks per-episode reward bookkeeping and disables exploration and
    learning once ``numTraining`` episodes have completed.
    """

    def __init__(self, numTraining=100, epsilon=0.5, alpha=0.5, gamma=1, name="ReinforcementAgent", color:chess.Color = chess.BLACK):
        """
        alpha - learning rate
        epsilon - exploration rate
        gamma - discount factor
        numTraining - number of training episodes, i.e. no learning after these many episodes
        """
        super().__init__(alpha, epsilon, gamma, numTraining, name, color)
        self.episodesSoFar = 0
        self.accumTrainRewards = 0.0
        self.accumTestRewards = 0.0
        # Initialize episode state up front so observeTransition()/stopEpisode()
        # cannot hit AttributeError if called before the first startEpisode().
        self.lastState = None
        self.lastAction = None
        self.episodeRewards = 0.0

    @abstractmethod
    def update(self, board, action, nextState, reward):
        """
        This class will call this function, which you write, after
        observing a transition and reward.
        """
        raise NotImplementedError

    def getLegalActions(self, board):
        """
        Get the actions available for a given state.
        """
        return board.legal_moves

    def observeTransition(self, board, action, nextState, deltaReward):
        """
        Called by environment to inform agent that a transition has
        been observed. This will result in a call to self.update
        on the same arguments.
        """
        self.episodeRewards += deltaReward
        self.update(board, action, nextState, deltaReward)

    def doAction(self, state, action):
        """
        Called by inherited class when an action is taken in a state.
        Remembers the (state, action) pair for the next transition.
        """
        self.lastState = state
        self.lastAction = action

    def startEpisode(self):
        """
        Start a training episode: reset per-episode state.
        """
        self.lastState = None
        self.lastAction = None
        self.episodeRewards = 0.0

    def stopEpisode(self):
        """
        Stop a training episode: fold episode reward into the running
        train/test totals and disable learning once training is done.
        """
        if self.episodesSoFar < self.numTraining:
            self.accumTrainRewards += self.episodeRewards
        else:
            self.accumTestRewards += self.episodeRewards
        self.episodesSoFar += 1
        if self.episodesSoFar >= self.numTraining:
            # Take off the training wheels
            self.epsilon = 0.0    # no exploration
            self.alpha = 0.0      # no learning

    def isInTraining(self):
        """Return True while fewer than numTraining episodes have finished."""
        return self.episodesSoFar < self.numTraining

    def isInTesting(self):
        """Return True once training has completed."""
        return not self.isInTraining()
158 changes: 158 additions & 0 deletions src/agents/qlearning_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""QLearning Agent."""

from .learning_agent import ReinforcementAgent
from collections import defaultdict
import random
import chess
from utils import flipCoin
from evaluation import evaluate

class QLearningAgent(ReinforcementAgent):
    """Q-Learning Agent.

    Learns a tabular action-value function Q(s, a) keyed on the board FEN
    and the move, trained by self-play against a random-moving opponent.
    """

    def __init__(self, **args):
        ReinforcementAgent.__init__(self, **args)
        # Q-table: (fen, move) -> float; unseen pairs default to 0.0.
        self.q_values = defaultdict(float)
        # NOTE: training runs eagerly at construction time (numTraining games).
        self.train()

    def choose_move(self, board: chess.Board) -> chess.Move:
        """
        Determine best move based on current board state and Q-values.

        Args:
            board: The chess board to get the move for.

        Returns:
            chess.Move: The move QLearningAgent decides upon.
        """
        return self.computeActionFromQValues(board)

    def getQValue(self, board: chess.Board, action: chess.Move) -> float:
        """
        Returns the Q-value for a given state and action.

        Returns 0.0 if we have never seen the (state, action) pair,
        or the stored Q node value otherwise.

        Args:
            board: The board to get the value for.
            action: The move to take from the state.
        """
        return self.q_values[(board.fen(), action)]

    def computeValueFromQValues(self, board: chess.Board) -> float:
        """
        Returns the best Q-value over legal moves from the given board.
        Minimizes if this agent plays black, otherwise maximizes.
        Returns 0.0 for terminal states.

        Args:
            board: The state from which to return the best value.
        """
        if board.is_game_over() or not board.legal_moves:
            return 0.0

        # BUG FIX: query Q(s, a) on the *current* board. The previous code
        # pushed each move first, so lookups keyed on the post-move FEN while
        # update() stored values under the pre-move FEN — learned values were
        # never read back.
        values = [self.getQValue(board, move) for move in board.legal_moves]
        return min(values) if self.color == chess.BLACK else max(values)

    def computeActionFromQValues(self, board: chess.Board) -> chess.Move:
        """
        Compute the best action to take in a state (minimizes if color is
        black, maximizes otherwise). Ties are broken uniformly at random.
        Returns None for terminal states.

        Args:
            board: The state from which to return the best action.
        """
        # NOTE(review): mutates the caller's board so legal_moves are generated
        # for this agent's color — confirm callers expect this side effect.
        board.turn = self.color
        if board.is_game_over() or not board.legal_moves:
            return None

        bestVal = float('inf') if self.color == chess.BLACK else float('-inf')
        bestMoves = []
        for move in board.legal_moves:
            # Same fix as computeValueFromQValues: key on the pre-move state.
            curVal = self.getQValue(board, move)
            if curVal == bestVal:
                bestMoves.append(move)
            elif curVal < bestVal and self.color == chess.BLACK:
                bestVal = curVal
                bestMoves = [move]
            elif curVal > bestVal and self.color == chess.WHITE:
                bestVal = curVal
                bestMoves = [move]

        return random.choice(bestMoves)

    def getAction(self, board: chess.Board) -> chess.Move:
        """
        Compute the action to take in the current state, using
        epsilon-greedy exploration.

        Args:
            board: The state from which the best action should be chosen.
        """
        board.turn = self.color
        if flipCoin(self.epsilon):
            # Explore: uniformly random legal move.
            action = random.choice(list(board.legal_moves))
        else:
            # Exploit: greedy with respect to the current Q-table.
            action = self.computeActionFromQValues(board)

        return action

    def final(self, board: chess.Board):
        """
        Called after an episode ends: record the terminal transition and
        close the episode's bookkeeping.

        Args:
            board: The ending state of the board.
        """
        # Guard: if this agent never moved in the episode there is no
        # (lastState, lastAction) pair to learn from — evaluate(None) would crash.
        if self.lastState is not None:
            deltaReward = evaluate(board) - evaluate(self.lastState)
            self.observeTransition(self.lastState, self.lastAction, board, deltaReward)
        self.stopEpisode()

    def registerInitialState(self, board: chess.Board):
        """Start a training episode."""
        self.startEpisode()

    def update(self, board: chess.Board, action: chess.Move, nextBoard: chess.Board, reward: float):
        """
        Performs the Q-learning update for the observed transition
        state = action => nextState with the given reward:

            Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * V(s'))

        Args:
            board: The current state of the board.
            action: The chosen action.
            nextBoard: The board state after performing the action.
            reward: The reward for the action taken.
        """
        sample = reward + self.discount * self.computeValueFromQValues(nextBoard)
        self.q_values[(board.fen(), action)] = (1 - self.alpha) * self.getQValue(board, action) + self.alpha * sample

    def train(self):
        """
        Train the Q-learning agent for numTraining games against an
        opponent making uniformly random moves.
        """
        for _ in range(self.numTraining):
            board = chess.Board()
            self.startEpisode()
            while not board.is_game_over():
                if board.turn == self.color:
                    state = board.copy()
                    action = self.getAction(state)
                    self.doAction(state, action)
                    board.push(action)
                    nextState = board
                    # Shaped reward: change in static evaluation across the move.
                    reward = evaluate(nextState) - evaluate(state)
                    self.observeTransition(state, action, nextState, reward)
                else:
                    opp_moves = list(board.legal_moves)
                    if opp_moves:
                        board.push(random.choice(opp_moves))
            self.final(board)
Loading