-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
120 lines (100 loc) · 3.69 KB
/
train.py
File metadata and controls
120 lines (100 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# train.py
import argparse
import os
import torch
from rldoom.configs import make_config
from rldoom.utils.logger import Logger
from rldoom.utils.seeding import set_seed
from rldoom.agents.dqn import DQNAgent
from rldoom.agents.ddqn import DDQNAgent
from rldoom.agents.dddqn import DDDQNAgent
from rldoom.agents.rainbow import RainbowAgent
from rldoom.agents.reinforce import ReinforceAgent
from rldoom.agents.a2c import A2CAgent
from rldoom.agents.a3c import A3CAgent
from rldoom.agents.ppo import PPOAgent
from rldoom.agents.trpo import TRPOAgent
from rldoom.trainers.offpolicy import train_offpolicy
from rldoom.trainers.onpolicy import train_onpolicy
def build_agent(algo: str, obs_shape, num_actions: int, cfg, device):
    """Instantiate the agent class that implements *algo*.

    Args:
        algo: Algorithm identifier, e.g. ``"dqn"`` or ``"ppo_tuned"``.
        obs_shape: Observation shape forwarded to the agent (C, H, W).
        num_actions: Size of the discrete action space.
        cfg: Config object carrying the algorithm's hyperparameters.
        device: Torch device the agent should run on.

    Returns:
        A freshly constructed agent for the requested algorithm.

    Raises:
        ValueError: If *algo* does not name a known algorithm.
    """
    ctor_args = (obs_shape, num_actions, cfg, device)
    # --- Off-policy family ---
    if algo == "dqn":
        agent_cls = DQNAgent
    elif algo == "ddqn":
        agent_cls = DDQNAgent
    elif algo in ("dddqn", "dddqn_tuned"):
        # Tuned variant reuses the class; only cfg hyperparameters differ.
        agent_cls = DDDQNAgent
    elif algo == "rainbow":
        agent_cls = RainbowAgent
    # --- On-policy family ---
    elif algo in ("reinforce", "reinforce_tuned"):
        agent_cls = ReinforceAgent
    elif algo in ("a2c", "a2c_tuned"):
        agent_cls = A2CAgent
    elif algo == "a3c":
        agent_cls = A3CAgent
    elif algo in ("ppo", "ppo_tuned"):
        agent_cls = PPOAgent
    elif algo == "trpo":
        agent_cls = TRPOAgent
    else:
        raise ValueError(f"Unknown algorithm: {algo}")
    return agent_cls(*ctor_args)
def main():
    """Main training entrypoint.

    - Parses CLI arguments.
    - Builds config from YAML.
    - Sets random seeds and selects device.
    - Instantiates the requested Agent.
    - Dispatches to on-policy or off-policy trainer.
    """
    cli = argparse.ArgumentParser()
    algo_choices = [
        # baseline off-policy
        "dqn", "ddqn", "dddqn", "rainbow",
        # baseline on-policy
        "reinforce", "a2c", "a3c", "ppo", "trpo",
        # tuned variants
        "reinforce_tuned", "a2c_tuned", "dddqn_tuned", "ppo_tuned",
    ]
    cli.add_argument(
        "--algo",
        type=str,
        default="dqn",
        choices=algo_choices,
        help="Name of the RL algorithm to train.",
    )
    cli.add_argument("--seed", type=int, default=0, help="Random seed.")
    args = cli.parse_args()

    # Assemble the config (env + train + logging + algo hyperparams) from YAML.
    cfg = make_config(args.algo, args.seed)
    # Seed Python, NumPy, Torch (and envs if needed) for reproducibility.
    set_seed(cfg.seed)

    # Prefer GPU when one is available.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Stacked-frame observation tensor shape: (C, H, W).
    frame_shape = (cfg.stack_size, cfg.frame_size, cfg.frame_size)
    # Deadly Corridor has 7 discrete actions.
    action_count = 7

    # Build the agent and the logging wrapper (wandb + console).
    agent = build_agent(cfg.algo, frame_shape, action_count, cfg, run_device)
    logger = Logger(cfg)

    # Hand off to the trainer that matches the algorithm family.
    trainers = {"offpolicy": train_offpolicy, "onpolicy": train_onpolicy}
    if cfg.algo_type not in trainers:
        raise ValueError(f"Unknown algo_type: {cfg.algo_type}")
    trainers[cfg.algo_type](agent, cfg, logger)
# Run training only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()