-import sys
+from collections import defaultdict
+
 import gym
 import numpy as np
-from collections import defaultdict, deque
-import matplotlib.pyplot as plt
 
 import check_test
 from plot_utils import plot_values
 
 env = gym.make('CliffWalking-v0')
 
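+# CliffWalking-v0: 4 x 12 grid (48 states), 4 actions (UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3)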
-# print(env.action_space)
-# print(env.observation_space)
-#
-# # define the optimal state-value function
-# V_opt = np.zeros((4, 12))
-# V_opt[0:13][0] = -np.arange(3, 15)[::-1]
-# V_opt[0:13][1] = -np.arange(3, 15)[::-1] + 1
-# V_opt[0:13][2] = -np.arange(3, 15)[::-1] + 2
-# V_opt[3][0] = -13
-#
-# plot_values(V_opt)
-
-
-def sarsa(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
+def q_learning(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
     # decide epsilon
     epsilon = epsilon_start
     epsilon_min = 0.1
-    epsilon_decay = 0.9999
+    epsilon_decay = 0.997
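+    # with this schedule epsilon decays from 1.0 to the 0.1 floor in roughly 770 episodes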
+
+    nA = 4
 
     # initialize action-value function (empty dictionary of arrays)
     Q = defaultdict(lambda: np.zeros(env.nA))
@@ -35,53 +23,35 @@ def sarsa(env, num_episodes, alpha, gamma=1.0, epsilon_start=1.0):
     for i_episode in range(1, num_episodes + 1):
 
         # monitor progress
-        if i_episode % 499999 == 0:
-            print("\rEpisode {}/{}".format(i_episode, num_episodes), end="")
-            print(str(Q))
+        # if i_episode % 1 == 0:
+        #     print("\rEpisode {}/{}.".format(i_episode, num_episodes), end="")
+        #
+        #     sys.stdout.flush()
 
+        if i_episode % 100 == 0:
+            print(f"Episode {i_episode}: Epsilon = {epsilon:.4f}")
 
         # set the value of epsilon
         epsilon = max(epsilon * epsilon_decay, epsilon_min)
 
         # generate episode
-        episode = generate_episode(env=env, Q=Q, epsilon=epsilon, nA=4)
-
-        Q = update_q(episode, Q, alpha, gamma)
-    return Q
-
-
-def generate_episode(env, Q, epsilon, nA):
-    episode = []
-    state, _ = env.reset()
-    if isinstance(state, dict):
-        state = tuple(sorted(state.items()))
+        state, _ = env.reset()
 
-    action = np.random.choice(np.arange(nA),
-                              p=get_probs(Q[state], epsilon, nA)) if state in Q else env.action_space.sample()
+        while True:
+            action = np.random.choice(np.arange(nA), p=get_probs(Q[state], epsilon, nA))
 
-    while True:
-        if isinstance(state, tuple):
-            state = state[0]  # Extract actual state if (state, info) is returned
+            next_state, reward, terminated, truncated, _ = env.step(action)  # ✅ New API
+            if next_state not in Q:
+                Q[next_state] = np.zeros(nA)  # ✅ Ensure Q-value initialization
 
-        # ✅ Convert state to tuple if it’s a dictionary
-        if isinstance(state, dict):
-            state = tuple(sorted(state.items()))
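+            # Q-learning (sarsamax) update: unlike the removed SARSA code, bootstrap from
+            # the greedy value max_a Q(next_state, a) rather than from the next action taken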
-        next_state, reward, terminated, truncated, _ = env.step(action)  # ✅ New API
-        if next_state not in Q:
-            Q[next_state] = np.zeros(nA)  # ✅ Ensure Q-value initialization
+            next_Q = 0 if terminated else np.max(Q[next_state])
+            Q[state][action] += alpha * (reward + gamma * next_Q - Q[state][action])
 
-        next_action = np.random.choice(np.arange(nA), p=get_probs(Q[next_state], epsilon, nA))
-
-        episode.append((state, action, reward))
-
-        if terminated or truncated:
-            break
-
-        state = next_state  # ✅ Track next state
-        action = next_action  # ✅ Track next action
-
-    return episode
+            # stop only after the terminal transition has been used to update Q
+            if terminated or truncated:
+                break
 
+            state = next_state
+    return Q
 
 def get_probs(Q_s, epsilon, nA):
     """ obtains the action probabilities corresponding to epsilon-greedy policy """
@@ -90,37 +60,16 @@ def get_probs(Q_s, epsilon, nA):
     policy_states[best_action] = 1 - epsilon + (epsilon / nA)
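+    # e.g. with epsilon = 0.1 and nA = 4, the greedy action gets probability
+    # 1 - 0.1 + 0.1/4 = 0.925 and each of the other three actions 0.1/4 = 0.025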
     return policy_states
 
-def pick_action(epsilon, Q, next_state):
-    if np.random.rand() < epsilon:
-        next_action = env.action_space.sample()  # Explore (random action)
-    else:
-        next_action = np.argmax(Q[next_state])  # Exploit (best action)
-
-    return next_action
-
-def update_q(episode, Q, alpha, gamma):
-    """ updates the action-value function estimate using the most recent episode """
-    states, actions, rewards = zip(*episode)
-    # prepare for discounting
-    for i in range(len(states) - 1):  # Ignore last step
-        state, action = states[i], actions[i]
-        next_state, next_action = states[i + 1], actions[i + 1]  # ✅ Use episode step
-
-        old_Q = Q[state][action]
-        next_Q = Q[next_state][next_action]  # ✅ Correct SARSA update
-        Q[state][action] = old_Q + alpha * (rewards[i] + gamma * next_Q - old_Q)
-    return Q
-
 
 # obtain the estimated optimal policy and corresponding action-value function
-Q_sarsa = sarsa(env, 500000, .01)
+Q_q_learning = q_learning(env, 5000, .01)
 
 # print the estimated optimal policy
-policy_sarsa = np.array([np.argmax(Q_sarsa[key]) if key in Q_sarsa else -1 for key in np.arange(48)]).reshape(4, 12)
-check_test.run_check('td_control_check', policy_sarsa)
+policy_q_learning = np.array([np.argmax(Q_q_learning[key]) if key in Q_q_learning else -1 for key in np.arange(48)]).reshape(4, 12)
+check_test.run_check('td_control_check', policy_q_learning)
 print("\nEstimated Optimal Policy (UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3, N/A = -1):")
-print(policy_sarsa)
+print(policy_q_learning)
 
 # plot the estimated optimal state-value function
-V_sarsa = ([np.max(Q_sarsa[key]) if key in Q_sarsa else 0 for key in np.arange(48)])
-plot_values(V_sarsa)
+V_q_learning = ([np.max(Q_q_learning[key]) if key in Q_q_learning else 0 for key in np.arange(48)])
+plot_values(V_q_learning)