Commit 90884dd

Temporal Difference Exercises
1 parent d595006 commit 90884dd

File tree

10 files changed: +130 -193 lines changed

.idea/.gitignore

Lines changed: 8 additions & 0 deletions (generated file; not rendered by default)

.idea/deep-reinforcement-learning.iml

Lines changed: 15 additions & 0 deletions (generated file; not rendered by default)

.idea/inspectionProfiles/profiles_settings.xml

Lines changed: 6 additions & 0 deletions (generated file; not rendered by default)

.idea/misc.xml

Lines changed: 7 additions & 0 deletions (generated file; not rendered by default)

.idea/modules.xml

Lines changed: 8 additions & 0 deletions (generated file; not rendered by default)

.idea/vcs.xml

Lines changed: 6 additions & 0 deletions (generated file; not rendered by default)

monte-carlo/blackjack.py

Whitespace-only changes.
Lines changed: 30 additions & 81 deletions
@@ -1,32 +1,20 @@
-import sys
+from collections import defaultdict
+
 import gym
 import numpy as np
-from collections import defaultdict, deque
-import matplotlib.pyplot as plt
 
 import check_test
 from plot_utils import plot_values
 
 env = gym.make('CliffWalking-v0')
 
-# print(env.action_space)
-# print(env.observation_space)
-#
-# # define the optimal state-value function
-# V_opt = np.zeros((4, 12))
-# V_opt[0:13][0] = -np.arange(3, 15)[::-1]
-# V_opt[0:13][1] = -np.arange(3, 15)[::-1] + 1
-# V_opt[0:13][2] = -np.arange(3, 15)[::-1] + 2
-# V_opt[3][0] = -13
-#
-# plot_values(V_opt)
-
-
-def sarsa(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
+def q_learning(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
     # decide epsilon
     epsilon = epsilon_start
     epsilon_min = 0.1
-    epsilon_decay = 0.9999
+    epsilon_decay = 0.997
+
+    nA = 4
 
     # initialize action-value function (empty dictionary of arrays)
     Q = defaultdict(lambda: np.zeros(env.nA))
@@ -35,53 +23,35 @@ def sarsa(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
     for i_episode in range(1, num_episodes + 1):
 
         # monitor progress
-        if i_episode % 499999 == 0:
-            print("\rEpisode {}/{}".format(i_episode, num_episodes), end="")
-            print (str(Q))
+        # if i_episode % 1 == 0:
+        #     print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
+        #
+        #     sys.stdout.flush()
 
+        if i_episode % 100 == 0:
+            print(f"Episode {i_episode}: Epsilon = {epsilon:.4f}")
 
         # set the value of epsilon
         epsilon = max(epsilon * epsilon_decay, epsilon_min)
 
         # generate episode
-        episode = generate_episode(env=env, Q=Q, epsilon=epsilon, nA=4)
-
-        Q = update_q(episode, Q, alpha, gamma)
-    return Q
-
-
-def generate_episode(env, Q, epsilon, nA):
-    episode = []
-    state, _ = env.reset()
-    if isinstance(state, dict):
-        state = tuple(sorted(state.items()))
+        state, _ = env.reset()
 
-    action = np.random.choice(np.arange(nA),
-                              p=get_probs(Q[state], epsilon, nA)) if state in Q else env.action_space.sample()
+        while True:
+            action = np.random.choice(np.arange(nA), p=get_probs(Q[state], epsilon, nA))
 
-    while True:
-        if isinstance(state, tuple):
-            state = state[0]  # Extract actual state if (state, info) is returned
+            next_state, reward, terminated, truncated, _ = env.step(action)  # ✅ New API
+            if next_state not in Q:
+                Q[next_state] = np.zeros(nA)  # ✅ Ensure Q-value initialization
 
-        # ✅ Convert state to tuple if it’s a dictionary
-        if isinstance(state, dict):
-            state = tuple(sorted(state.items()))
+            if terminated or truncated:
+                break
 
-        next_state, reward, terminated, truncated, _ = env.step(action)  # ✅ New API
-        if next_state not in Q:
-            Q[next_state] = np.zeros(nA)  # ✅ Ensure Q-value initialization
+            next_Q = 0 if terminated else np.max(Q[next_state])
+            Q[state][action] += alpha * (reward + gamma * next_Q - Q[state][action])
 
-        next_action = np.random.choice(np.arange(nA), p=get_probs(Q[next_state], epsilon, nA))
-
-        episode.append((state, action, reward))
-
-        if terminated or truncated:
-            break
-
-        state = next_state  # ✅ Track next state
-        action = next_action  # ✅ Track next action
-
-    return episode
+            state = next_state
+    return Q
 
 def get_probs(Q_s, epsilon, nA):
     """ obtains the action probabilities corresponding to epsilon-greedy policy """
@@ -90,37 +60,16 @@ def get_probs(Q_s, epsilon, nA):
     policy_states[best_action] = 1 - epsilon + (epsilon / nA)
     return policy_states
 
-def pick_action(epsilon, Q, next_state):
-    if np.random.rand() < epsilon:
-        next_action = env.action_space.sample()  # Explore (random action)
-    else:
-        next_action = np.argmax(Q[next_state])  # Exploit (best action)
-
-    return next_action
-
-def update_q(episode, Q, alpha, gamma):
-    """ updates the action-value function estimate using the most recent episode """
-    states, actions, rewards = zip(*episode)
-    # prepare for discounting
-    for i in range(len(states) - 1):  # Ignore last step
-        state, action = states[i], actions[i]
-        next_state, next_action = states[i + 1], actions[i + 1]  # ✅ Use episode step
-
-        old_Q = Q[state][action]
-        next_Q = Q[next_state][next_action]  # ✅ Correct SARSA update
-        Q[state][action] = old_Q + alpha * (rewards[i] + gamma * next_Q - old_Q)
-    return Q
-
 
 # obtain the estimated optimal policy and corresponding action-value function
-Q_sarsa = sarsa(env, 500000, .01)
+Q_q_learning = q_learning(env, 5000, .01)
 
 # print the estimated optimal policy
-policy_sarsa = np.array([np.argmax(Q_sarsa[key]) if key in Q_sarsa else -1 for key in np.arange(48)]).reshape(4,12)
-check_test.run_check('td_control_check', policy_sarsa)
+policy_q_learning = np.array([np.argmax(Q_q_learning[key]) if key in Q_q_learning else -1 for key in np.arange(48)]).reshape(4,12)
+check_test.run_check('td_control_check', policy_q_learning)
 print("\nEstimated Optimal Policy (UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3, N/A = -1):")
-print(policy_sarsa)
+print(policy_q_learning)
 
 # plot the estimated optimal state-value function
-V_sarsa = ([np.max(Q_sarsa[key]) if key in Q_sarsa else 0 for key in np.arange(48)])
-plot_values(V_sarsa)
+V_q_learning = ([np.max(Q_q_learning[key]) if key in Q_q_learning else 0 for key in np.arange(48)])
+plot_values(V_q_learning)
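
For reference, the new q_learning loop above selects actions epsilon-greedily but bootstraps from max_a Q(s', a), i.e. the off-policy TD(0) control update. A minimal standalone sketch of that single-step update, with an illustrative helper name and a made-up transition rather than anything taken from the committed file:

from collections import defaultdict

import numpy as np

nA = 4  # CliffWalking has 4 actions (UP, RIGHT, DOWN, LEFT)


def q_learning_update(Q, state, action, reward, next_state, terminated,
                      alpha=0.01, gamma=1.0):
    """One Q-learning step: move Q(s, a) toward r + gamma * max_a' Q(s', a')."""
    target = reward if terminated else reward + gamma * np.max(Q[next_state])
    Q[state][action] += alpha * (target - Q[state][action])
    return Q


# Made-up CliffWalking-style transition, purely for illustration.
Q = defaultdict(lambda: np.zeros(nA))
Q = q_learning_update(Q, state=36, action=0, reward=-1, next_state=24, terminated=False)
print(Q[36][0])  # -0.01 after one update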

temporal-difference/TD_CliffWalking_SARSA_Solution.py

Lines changed: 9 additions & 31 deletions
@@ -9,24 +9,11 @@
 
 env = gym.make('CliffWalking-v0')
 
-# print(env.action_space)
-# print(env.observation_space)
-#
-# # define the optimal state-value function
-# V_opt = np.zeros((4, 12))
-# V_opt[0:13][0] = -np.arange(3, 15)[::-1]
-# V_opt[0:13][1] = -np.arange(3, 15)[::-1] + 1
-# V_opt[0:13][2] = -np.arange(3, 15)[::-1] + 2
-# V_opt[3][0] = -13
-#
-# plot_values(V_opt)
-
-
-def sarsa(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
+def sarsa(env, num_episodes, alpha, gamma=1.0, epsilon_start=0.5):
     # decide epsilon
     epsilon = epsilon_start
     epsilon_min = 0.1
-    epsilon_decay = 0.9999
+    epsilon_decay = 0.99
 
     # initialize action-value function (empty dictionary of arrays)
     Q = defaultdict(lambda: np.zeros(env.nA))
@@ -35,10 +22,9 @@ def sarsa(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
     for i_episode in range(1, num_episodes + 1):
 
         # monitor progress
-        if i_episode % 499999 == 0:
-            print("\rEpisode {}/{}".format(i_episode, num_episodes), end="")
-            print (str(Q))
-
+        if i_episode % 1 == 0:
+            print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
+            sys.stdout.flush()
 
         # set the value of epsilon
         epsilon = max(epsilon * epsilon_decay, epsilon_min)
@@ -53,20 +39,12 @@ def sarsa(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
 def generate_episode(env, Q, epsilon, nA):
     episode = []
     state, _ = env.reset()
-    if isinstance(state, dict):
-        state = tuple(sorted(state.items()))
 
-    action = np.random.choice(np.arange(nA),
-                              p=get_probs(Q[state], epsilon, nA)) if state in Q else env.action_space.sample()
+    action = np.random.choice(np.arange(nA), p=get_probs(Q[state], epsilon, nA)) if state not in Q else np.argmax(
+        Q[state]
+    )
 
     while True:
-        if isinstance(state, tuple):
-            state = state[0]  # Extract actual state if (state, info) is returned
-
-        # ✅ Convert state to tuple if it’s a dictionary
-        if isinstance(state, dict):
-            state = tuple(sorted(state.items()))
-
         next_state, reward, terminated, truncated, _ = env.step(action)  # ✅ New API
         if next_state not in Q:
             Q[next_state] = np.zeros(nA)  # ✅ Ensure Q-value initialization
@@ -113,7 +91,7 @@ def update_q(episode, Q, alpha, gamma):
 
 
 # obtain the estimated optimal policy and corresponding action-value function
-Q_sarsa = sarsa(env, 500000, .01)
+Q_sarsa = sarsa(env, 5000, .01)
 
 # print the estimated optimal policy
 policy_sarsa = np.array([np.argmax(Q_sarsa[key]) if key in Q_sarsa else -1 for key in np.arange(48)]).reshape(4,12)
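
The SARSA solution keeps the episodic generate_episode / update_q pair; the rule update_q applies is the on-policy TD(0) update, which bootstraps from Q(s', a') for the action actually chosen next rather than from the greedy maximum. A minimal standalone sketch of that single-step update (illustrative names and a made-up transition, not code from this commit):

from collections import defaultdict

import numpy as np

nA = 4  # CliffWalking has 4 actions


def sarsa_update(Q, state, action, reward, next_state, next_action,
                 alpha=0.01, gamma=1.0):
    """One SARSA step: move Q(s, a) toward r + gamma * Q(s', a')."""
    old_Q = Q[state][action]
    Q[state][action] = old_Q + alpha * (reward + gamma * Q[next_state][next_action] - old_Q)
    return Q


# Made-up transition and next action, purely for illustration.
Q = defaultdict(lambda: np.zeros(nA))
Q = sarsa_update(Q, state=36, action=0, reward=-1, next_state=24, next_action=1)
print(Q[36][0])  # -0.01 after one update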
