
Commit 0c4c335

Add workflow (#188)

* Add workflow

  Add issue template
  Switch to pathlib
  Fix action scaling in observation stacking
  Fix loading menu
  Update README.md
  Format project
  Convert list passed in torch tensor to numpy array

  Co-authored-by: werner-duvaud <40442230+werner-duvaud@users.noreply.github.com>

* Update ci-testing.yaml

* Fix requirements.txt and update ci-testing.yaml

* Update ci-testing.yaml

Co-authored-by: werner-duvaud <40442230+werner-duvaud@users.noreply.github.com>
1 parent 23a1f69 commit 0c4c335

25 files changed: +685 −197 lines changed
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
name: 🐛 Bug Report
description: Create a report to help us reproduce and fix the bug
labels: [bug]
body:
  - type: markdown
    attributes:
      value: |
        Thank you for submitting a MuZero 🐛 Bug Report!

  - type: checkboxes
    attributes:
      label: Search before asking
      description: >
        Please search the [issues](https://github.com/werner-duvaud/muzero-general/issues) to see if a similar bug report already exists.
      options:
        - label: >
            I have searched the MuZero [issues](https://github.com/werner-duvaud/muzero-general/issues) and found no similar bug report.
          required: true

  - type: textarea
    attributes:
      label: 🐛 Describe the bug
      description: |
        Please provide a clear and concise description of what the bug is.
    validations:
      required: true

  - type: textarea
    attributes:
      label: Add an example
      description: Provide console output with error messages and/or screenshots of the bug.
      placeholder: |
        💡 ProTip! Include as much information as possible (screenshots, logs, tracebacks, etc.) to receive the most helpful response.
    validations:
      required: true

  - type: textarea
    attributes:
      label: Environment
      description: Please specify the software and hardware you used to produce the bug.
      placeholder: |
        For example:
          - torch 1.9.0+cu111 CUDA:0 (A100-SXM4-40GB, 40536MiB)
          - OS: Ubuntu 20.04
          - Python: 3.9.0
    validations:
      required: false

  - type: textarea
    attributes:
      label: Minimal Reproducible Example
      description: >
        This code will help us reproduce the issue so we can track down its source.
      placeholder: |
        ```
        # Code to reproduce the issue here
        ```
    validations:
      required: false

  - type: textarea
    attributes:
      label: Additional
      description: Anything else you would like to share?

  - type: markdown
    attributes:
      value: >
        Thanks for contributing 🎉!
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
name: 🚀 Feature Request
description: Submit a request for a new MuZero feature
labels: [enhancement]
body:
  - type: markdown
    attributes:
      value: |
        Thank you for submitting a MuZero 🚀 Feature Request!

  - type: checkboxes
    attributes:
      label: Search before asking
      description: >
        Please search the [issues](https://github.com/werner-duvaud/muzero-general/issues) to see if a similar feature request already exists.
      options:
        - label: >
            I have searched the MuZero [issues](https://github.com/werner-duvaud/muzero-general/issues) and found no similar feature requests.
          required: true

  - type: textarea
    attributes:
      label: Description
      description: A short description of your feature.
      placeholder: |
        What new feature would you like to see in MuZero?
    validations:
      required: true

  - type: textarea
    attributes:
      label: Additional context
      description: >
        Add any other context or screenshots about the feature request.

  - type: markdown
    attributes:
      value: >
        Thanks for contributing 🎉!

.github/workflows/ci-testing.yaml

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
name: CI testing

on: [push, pull_request]

jobs:
  training-test:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: [3.7]

    timeout-minutes: 90
    steps:
      - uses: actions/checkout@v2

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install black
        run: "pip install black"

      - name: Run black
        run: "black --check --diff ."

      - name: Install dependencies
        run: |
          pip install -r requirements.txt

      - name: Training test
        shell: bash
        run: |
          # Launch the cartpole experiment and store the training log
          python muzero.py cartpole '{"training_steps": 7500}' 2>&1 | tee log.txt

      - name: Archive log artifact
        uses: actions/upload-artifact@v2
        with:
          name: training_test_logs
          path: log.txt

      - name: Retrieve training test log
        uses: actions/download-artifact@v2
        with:
          name: training_test_logs

      - name: Check reward
        shell: bash
        run: |
          # Retrieve the highest reward reached during training
          BEST_REWARD=$(sed -n -E 's/^.*reward: ([0-9]+).*$/\1/p' log.txt | sort -n | tail -1)

          # Display the best reward
          echo "Best reward of cartpole training: $BEST_REWARD"

          # Fail the job if the best reward stays below the threshold
          if ((BEST_REWARD < 250)); then
            exit 1
          fi

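The `Check reward` step above scrapes the highest `reward: N` value out of the training log with a sed/sort/tail pipeline and fails the job when it stays below 250. A rough Python equivalent of that shell pipeline (a sketch only, assuming the same `reward: <integer>` log format; the sample log lines are invented for illustration):

```python
import re

def reward_check(log_text: str, threshold: int = 250) -> bool:
    # Mirror the sed pattern: capture the integer after "reward: " anywhere in the log
    rewards = [int(m.group(1)) for m in re.finditer(r"reward: (\d+)", log_text)]
    # max(...) replaces the "sort -n | tail -1" step; default 0 if no reward was logged
    best = max(rewards, default=0)
    print(f"Best reward of cartpole training: {best}")
    # The workflow exits 1 (job failure) when the best reward is below the threshold
    return best >= threshold

# Hypothetical log excerpt in the format the sed command expects
log = "Step 100, reward: 12\nStep 200, reward: 180\nStep 300, reward: 260\n"
reward_check(log)
```

The numeric sort matters in both versions: a lexicographic sort would rank `"9"` above `"180"`.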
README.md

Lines changed: 12 additions & 4 deletions
@@ -5,10 +5,14 @@
 ![license MIT](https://img.shields.io/badge/licence-MIT-green)
 [![discord badge](https://img.shields.io/badge/discord-join-6E60EF)](https://discord.gg/GB2vwsF)

+![ci-testing workflow](https://github.com/werner-duvaud/muzero-general/workflows/CI%20testing/badge.svg)
+
 # MuZero General

-A commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) (Nov 2019) and the associated [pseudocode](https://arxiv.org/src/1911.08265v2/anc/pseudocode.py).
+A commented and [documented](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) implementation of MuZero based on the Google DeepMind [paper](https://arxiv.org/abs/1911.08265) (Schrittwieser et al., Nov 2019) and the associated [pseudocode](https://arxiv.org/src/1911.08265v2/anc/pseudocode.py).
 It is designed to be easily adaptable to any game or reinforcement learning environment (like [gym](https://github.com/openai/gym)). You only need to add a [game file](https://github.com/werner-duvaud/muzero-general/tree/master/games) with the hyperparameters and the game class. Please refer to the [documentation](https://github.com/werner-duvaud/muzero-general/wiki/MuZero-Documentation) and the [example](https://github.com/werner-duvaud/muzero-general/blob/master/games/cartpole.py).
+This implementation is primarily for educational purposes.\
+[Explanatory video of MuZero](https://youtu.be/We20YSAJZSE)

 MuZero is a state-of-the-art RL algorithm for board games (Chess, Go, ...) and Atari games.
 It is the successor to [AlphaZero](https://arxiv.org/abs/1712.01815) but without any knowledge of the environment's underlying dynamics. MuZero learns a model of the environment and uses an internal representation that contains only the information useful for predicting the reward, value, policy and transitions. MuZero is also close to [Value prediction networks](https://arxiv.org/abs/1707.03497). See [How it works](https://github.com/werner-duvaud/muzero-general/wiki/How-MuZero-works).

@@ -28,14 +32,13 @@ It is the successor to [AlphaZero](https://arxiv.org/abs/1712.01815) but without
 * [ ] Windows support (Experimental / Workaround: Use the [notebook](https://github.com/werner-duvaud/muzero-general/blob/master/notebook.ipynb) in [Google Colab](https://colab.research.google.com))

 ### Further improvements
-These improvements are active research, they are personal ideas and go beyond MuZero paper. We are open to contributions and other ideas.
+Here is a list of features which could be interesting to add but which are not in MuZero's paper. We are open to contributions and other ideas.

 * [x] [Hyperparameter search](https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization)
 * [x] [Continuous action space](https://github.com/werner-duvaud/muzero-general/tree/continuous)
 * [x] [Tool to understand the learned model](https://github.com/werner-duvaud/muzero-general/blob/master/diagnose_model.py)
-* [ ] Support of stochastic environments
+* [ ] Batch MCTS
 * [ ] Support of more than two player games
-* [ ] RL tricks (Never Give Up, Adaptive Exploration, ...)

 ## Demo

@@ -96,6 +99,11 @@ tensorboard --logdir ./results

 You can adapt the configurations of each game by editing the `MuZeroConfig` class of the respective file in the [games folder](https://github.com/werner-duvaud/muzero-general/tree/master/games).

+## Related work
+
+* [EfficientZero](https://arxiv.org/abs/2111.00210) (Weirui Ye, Shaohuai Liu, Thanard Kurutach, Pieter Abbeel, Yang Gao)
+* [Sampled MuZero](https://arxiv.org/abs/2104.06303) (Thomas Hubert, Julian Schrittwieser, Ioannis Antonoglou, Mohammadamin Barekatain, Simon Schmitt, David Silver)
+
 ## Authors

 * Werner Duvaud

diagnose_model.py

Lines changed: 1 addition & 1 deletion
@@ -364,4 +364,4 @@ def plot_trajectory(self):
         ax.set(ylabel="Timestep")
         ax.set_title(name)

-        plt.show(block=False)
+        plt.show(block=False)

games/abstract_game.py

Lines changed: 5 additions & 5 deletions
@@ -14,7 +14,7 @@ def __init__(self, seed=None):
     def step(self, action):
         """
         Apply action to the game.
-
+
         Args:
             action : action of the action_space to take.

@@ -28,7 +28,7 @@ def to_play(self):
         Return the current player.

         Returns:
-            The current player, it should be an element of the players list in the config.
+            The current player, it should be an element of the players list in the config.
         """
         return 0

@@ -37,7 +37,7 @@ def legal_actions(self):
         """
         Should return the legal actions at each turn, if it is not available, it can return
         the whole action space. At each turn, the game have to be able to handle one of returned actions.
-
+
         For complex game where calculating legal moves is too long, the idea is to define the legal actions
         equal to the action space but to return a negative reward if the action is illegal.

@@ -50,7 +50,7 @@ def legal_actions(self):
     def reset(self):
         """
         Reset the game for a new game.
-
+
         Returns:
             Initial observation of the game.
         """

@@ -79,7 +79,7 @@ def human_to_action(self):
         """
         choice = input(f"Enter the action to play for the player {self.to_play()}: ")
         while int(choice) not in self.legal_actions():
-            choice = input("Ilegal action. Enter another action : ")
+            choice = input("Illegal action. Enter another action : ")
         return int(choice)

     def expert_agent(self):
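The interface touched by the diff above (`step`, `to_play`, `legal_actions`, `reset`) is what every game file must provide. A minimal standalone sketch of that contract (this toy "walk to the right end" game is invented for illustration; a real game file would subclass `games.abstract_game.AbstractGame` from the repository):

```python
# Toy illustration of the game interface documented in abstract_game.py.
# The game itself (a 1-D track won by reaching position 3) is hypothetical.

class ToyGame:
    def __init__(self, seed=None):
        self.position = 0  # start at the left end of the track

    def step(self, action):
        # Apply action to the game: action 0 moves left, action 1 moves right
        self.position += 1 if action == 1 else -1
        done = self.position >= 3
        reward = 1 if done else 0
        return [self.position], reward, done  # observation, reward, done

    def to_play(self):
        # Single-player game: the current player is always 0
        return 0

    def legal_actions(self):
        # The full action space is cheap to enumerate here, so return it all
        return [0, 1]

    def reset(self):
        # Reset the game for a new game and return the initial observation
        self.position = 0
        return [self.position]

game = ToyGame()
obs = game.reset()
obs, reward, done = game.step(1)
print(obs, reward, done)  # after one step right: [1] 0 False
```

For games where computing `legal_actions` exactly is expensive, the docstring above suggests returning the whole action space and giving a negative reward for illegal moves instead.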
games/atari.py

Lines changed: 8 additions & 8 deletions
@@ -1,5 +1,5 @@
 import datetime
-import os
+import pathlib

 import gym
 import numpy

@@ -15,6 +15,7 @@
 class MuZeroConfig:
     def __init__(self):
+        # fmt: off
         # More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization

         self.seed = 0  # Seed for numpy, torch and the game

@@ -78,7 +79,7 @@ def __init__(self):

         ### Training
-        self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S"))  # Path to store the model weights and TensorBoard logs
+        self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S")  # Path to store the model weights and TensorBoard logs
         self.save_model = True  # Save the checkpoint in results_path as model.checkpoint
         self.training_steps = int(1000e3)  # Total number of training steps (ie weights update according to a batch)
         self.batch_size = 1024  # Number of parts of games to train on at each training step

@@ -114,7 +115,7 @@ def __init__(self):
         self.self_play_delay = 0  # Number of seconds to wait after each played game
         self.training_delay = 0  # Number of seconds to wait after each training step
         self.ratio = None  # Desired training steps per self played step ratio. Equivalent to a synchronous version, training can take much longer. Set it to None to disable it
-
+        # fmt: on

     def visit_softmax_temperature_fn(self, trained_steps):
         """

@@ -145,7 +146,7 @@ def __init__(self, seed=None):
     def step(self, action):
         """
         Apply action to the game.
-
+
         Args:
             action : action of the action_space to take.

@@ -162,9 +163,9 @@ def legal_actions(self):
         """
         Should return the legal actions at each turn, if it is not available, it can return
         the whole action space. At each turn, the game have to be able to handle one of returned actions.
-
+
         For complex game where calculating legal moves is too long, the idea is to define the legal actions
-        equal to the action space but to return a negative reward if the action is illegal.
+        equal to the action space but to return a negative reward if the action is illegal.

         Returns:
             An array of integers, subset of the action space.

@@ -174,7 +175,7 @@ def legal_actions(self):
     def reset(self):
         """
         Reset the game for a new game.
-
+
         Returns:
             Initial observation of the game.
         """

@@ -196,4 +197,3 @@ def render(self):
         """
         self.env.render()
         input("Press enter to take a step ")
-

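The commit's "Switch to pathlib" change replaces the `os.path` string juggling in `results_path` with `pathlib` operators. A small sketch of the equivalence (the `/repo/games/atari.py` path is invented and stands in for `__file__`; neither path needs to exist on disk):

```python
import datetime
import os
import pathlib

# Both expressions build <repo>/results/<game name>/<timestamp>,
# starting from a game file such as <repo>/games/atari.py.
game_file = "/repo/games/atari.py"  # hypothetical, stands in for __file__
timestamp = datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S")

# Old style: string manipulation with os.path, climbing out of games/ via ".."
old = os.path.join(
    os.path.dirname(os.path.realpath(game_file)),
    "../results",
    os.path.basename(game_file)[:-3],  # strip the ".py" suffix by slicing
    timestamp,
)

# New style (as introduced by this commit): parents[1] is the grandparent
# directory, and .stem drops the suffix without magic slice indices
new = (
    pathlib.Path(game_file).resolve().parents[1]
    / "results"
    / pathlib.Path(game_file).stem
    / timestamp
)

# After normalizing the "games/../results" segment, the two paths match
print(os.path.normpath(old) == str(new))
```

Besides readability, `.stem` is more robust than `[:-3]`, which silently mangles any filename not ending in a three-character suffix.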