
Commit 9128707

replace init-q-learning from aidojo-stable branch

1 parent 1b2b7e2

1 file changed: 12 additions, 20 deletions


agents/attackers/initialized_q_learning/initialized_q_agent.py
Lines changed: 12 additions & 20 deletions

@@ -114,11 +114,7 @@ def initialize_q_value(self, action_counts, action_type):
             self.transition_probabilities.get(action.split('.')[-1], {}).get(action_type_str, 0) * count
             for action, count in action_counts.items()
         )
-        # sum with for loop
-        prob_sum = 0
-        for action, count in action_counts.items():
-            action_type_str = action.split('.')[-1]
-            prob_sum += self.transition_probabilities.get(action_type_str, {}).get(action_type_str, 0) * count
+

         return prob_sum * 5
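The hunk above removes a leftover accumulator loop in initialize_q_value, keeping only the generator-expression sum() whose closing lines appear as context. A minimal, self-contained sketch of that refactor pattern follows; the dictionaries, keys, and probability values are illustrative stand-ins, not data taken from the repository.

    # Toy stand-ins for the agent's data; the real keys and probabilities live in the agent.
    transition_probabilities = {"ScanNetwork": {"FindServices": 0.4, "ExploitService": 0.1}}
    action_counts = {"ActionType.ScanNetwork": 3}
    action_type_str = "FindServices"

    # Before: accumulate the weighted probabilities in an explicit loop.
    prob_sum = 0
    for action, count in action_counts.items():
        key = action.split('.')[-1]
        prob_sum += transition_probabilities.get(key, {}).get(action_type_str, 0) * count

    # After: the same total from a single sum() over a generator expression,
    # which is what the kept lines in the hunk compute.
    prob_sum = sum(
        transition_probabilities.get(action.split('.')[-1], {}).get(action_type_str, 0) * count
        for action, count in action_counts.items()
    )

Note that the deleted loop keyed the outer lookup with action_type_str rather than with the action name, so it did not mirror the generator expression exactly; only the sum() version survives this commit.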

@@ -215,45 +211,41 @@ def update_epsilon_with_decay(self, episode_number)->float:
         return new_eps

     def play_game(self, observation, episode_num, testing=False):
-        """
-        The main function for the gameplay. Handles the main interaction loop.
-        """
+        if observation is None:
+            observation = self.request_game_reset() or self.register()
+
         num_steps = 0
         current_solution = []

-        # Run the whole episode
         while not observation.end:
-            # Store steps so far
             num_steps += 1
-            # Get next action. If we are not training, selection is different, so pass it as argument
             action, state_id = self.select_action(observation, testing)
             current_solution.append([action, None])

             if args.store_actions:
                 actions_logger.info(f"\tState:{observation.state}")
                 actions_logger.info(f"\tEnd:{observation.end}")
                 actions_logger.info(f"\tInfo:{observation.info}")
-            self.logger.info(f"Action selected:{action}")
-            # Perform the action and observe next observation
+            self._logger.info(f"Action selected:{action}")
+
             observation = self.make_step(action)
-
-            # Recompute the rewards
             observation = self.recompute_reward(observation)
+
             if not testing:
-                # If we are training update the Q-table
                 self.q_values[state_id, action] += self.alpha * (observation.reward + self.gamma * self.max_action_q(observation)) - self.q_values[state_id, action]
+
         if args.store_actions:
             actions_logger.info(f"\t State:{observation.state}")
             actions_logger.info(f"\t End:{observation.end}")
             actions_logger.info(f"\t Info:{observation.info}")
-        # update epsilon value
+
         if not testing:
             self.current_epsilon = self.update_epsilon_with_decay(episode_num)
-        # Reset the episode
-        _ = self.request_game_reset()
-        # This will be the last observation played before the reset
+
+        self.request_game_reset()
         return observation, num_steps

+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser('You can train the agent, or test it. \n Test is also to use the agent. \n During training and testing the performance is logged.')
     parser.add_argument("--host", help="Host where the game server is", default="127.0.0.1", action='store', required=False)
