-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain_gs.py
More file actions
30 lines (24 loc) · 881 Bytes
/
main_gs.py
File metadata and controls
30 lines (24 loc) · 881 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from gradeadreamer.gs import Trainer
import torch
import numpy as np
def seed_everything(seed):
    """Seed every RNG the training run may draw from.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), so a run is reproducible from ``seed`` alone.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # function-scope import, matching this script's local-import style

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed *all* CUDA devices, not just the current one; PyTorch silently
    # ignores this call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


def main(argv=None):
    """Build the config from YAML plus command line, seed the RNGs and train.

    Precedence: ``--prompt`` and an explicit ``--gpu`` override the YAML file.
    Any extra ``key=value`` arguments are OmegaConf dotlist overrides merged on
    top of the YAML file.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    # Only needed when running as a script, so imported here.
    import argparse
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=False, default="./configs/gs.yaml", help="path to the yaml config file")
    parser.add_argument("--gpu", required=False, default=None, help="GPU id; overrides gpu_id in the config (default: 0)")
    parser.add_argument("--prompt", required=True, help="prompt")
    # Unrecognised arguments are kept and passed to OmegaConf as dotlist overrides.
    args, extras = parser.parse_known_args(argv)

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    if args.gpu is not None:
        # An explicit --gpu flag wins over the config, consistent with --prompt.
        opt.gpu_id = args.gpu
    elif "gpu_id" not in opt:
        opt.gpu_id = "0"
    opt.prompt = args.prompt

    # seed (the config is expected to define `seed`)
    seed_everything(opt.seed)

    # train
    trainer = Trainer(opt)
    trainer.train(opt.iters)


if __name__ == "__main__":
    main()