|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import yaml |
| 4 | + |
| 5 | +global_print_hparams = True |
| 6 | +hparams = {} |
| 7 | + |
| 8 | + |
| 9 | +class Args: |
| 10 | + def __init__(self, **kwargs): |
| 11 | + for k, v in kwargs.items(): |
| 12 | + self.__setattr__(k, v) |
| 13 | + |
| 14 | + |
| 15 | +def override_config(old_config: dict, new_config: dict): |
| 16 | + for k, v in new_config.items(): |
| 17 | + if isinstance(v, dict) and k in old_config: |
| 18 | + override_config(old_config[k], new_config[k]) |
| 19 | + else: |
| 20 | + old_config[k] = v |
| 21 | + |
| 22 | + |
| 23 | +def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True): |
| 24 | + if config == '': |
| 25 | + parser = argparse.ArgumentParser(description='neural music') |
| 26 | + parser.add_argument('--config', type=str, default='', |
| 27 | + help='location of the data corpus') |
| 28 | + parser.add_argument('--exp_name', type=str, default='', help='exp_name') |
| 29 | + parser.add_argument('--hparams', type=str, default='', |
| 30 | + help='location of the data corpus') |
| 31 | + parser.add_argument('--infer', action='store_true', help='infer') |
| 32 | + parser.add_argument('--validate', action='store_true', help='validate') |
| 33 | + parser.add_argument('--reset', action='store_true', help='reset hparams') |
| 34 | + parser.add_argument('--debug', action='store_true', help='debug') |
| 35 | + args, unknown = parser.parse_known_args() |
| 36 | + else: |
| 37 | + args = Args(config=config, exp_name=exp_name, hparams=hparams_str, |
| 38 | + infer=False, validate=False, reset=False, debug=False) |
| 39 | + args_work_dir = '' |
| 40 | + if args.exp_name != '': |
| 41 | + args.work_dir = args.exp_name |
| 42 | + args_work_dir = f'checkpoints/{args.work_dir}' |
| 43 | + |
| 44 | + config_chains = [] |
| 45 | + loaded_config = set() |
| 46 | + |
| 47 | + def load_config(config_fn): # deep first |
| 48 | + if(config_fn.startswith("/")): |
| 49 | + config_fn_path=os.path.abspath(config_fn[1:]) |
| 50 | + else: |
| 51 | + config_fn_path=config_fn |
| 52 | + with open(config_fn_path, encoding='utf-8') as f: |
| 53 | + hparams_ = yaml.safe_load(f) |
| 54 | + loaded_config.add(config_fn) |
| 55 | + if 'base_config' in hparams_: |
| 56 | + ret_hparams = {} |
| 57 | + if not isinstance(hparams_['base_config'], list): |
| 58 | + hparams_['base_config'] = [hparams_['base_config']] |
| 59 | + for c in hparams_['base_config']: |
| 60 | + if c not in loaded_config: |
| 61 | + if c.startswith('.'): |
| 62 | + c = f'{os.path.dirname(config_fn)}/{c}' |
| 63 | + c = os.path.normpath(c) |
| 64 | + override_config(ret_hparams, load_config(c)) |
| 65 | + override_config(ret_hparams, hparams_) |
| 66 | + else: |
| 67 | + ret_hparams = hparams_ |
| 68 | + config_chains.append(config_fn) |
| 69 | + return ret_hparams |
| 70 | + |
| 71 | + global hparams |
| 72 | + assert args.config != '' or args_work_dir != '' |
| 73 | + saved_hparams = {} |
| 74 | + if args_work_dir != 'checkpoints/': |
| 75 | + ckpt_config_path = f'{args_work_dir}/config.yaml' |
| 76 | + if os.path.exists(ckpt_config_path): |
| 77 | + try: |
| 78 | + with open(ckpt_config_path, encoding='utf-8') as f: |
| 79 | + saved_hparams.update(yaml.safe_load(f)) |
| 80 | + except: |
| 81 | + pass |
| 82 | + if args.config == '': |
| 83 | + args.config = ckpt_config_path |
| 84 | + |
| 85 | + hparams_ = {} |
| 86 | + |
| 87 | + hparams_.update(load_config(args.config)) |
| 88 | + |
| 89 | + if not args.reset: |
| 90 | + hparams_.update(saved_hparams) |
| 91 | + hparams_['work_dir'] = args_work_dir |
| 92 | + |
| 93 | + if args.hparams != "": |
| 94 | + for new_hparam in args.hparams.split(","): |
| 95 | + k, v = new_hparam.split("=") |
| 96 | + if v in ['True', 'False'] or type(hparams_[k]) == bool: |
| 97 | + hparams_[k] = eval(v) |
| 98 | + else: |
| 99 | + hparams_[k] = type(hparams_[k])(v) |
| 100 | + |
| 101 | + if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer: |
| 102 | + os.makedirs(hparams_['work_dir'], exist_ok=True) |
| 103 | + with open(ckpt_config_path, 'w', encoding='utf-8') as f: |
| 104 | + yaml.safe_dump(hparams_, f) |
| 105 | + |
| 106 | + hparams_['infer'] = args.infer |
| 107 | + hparams_['debug'] = args.debug |
| 108 | + hparams_['validate'] = args.validate |
| 109 | + global global_print_hparams |
| 110 | + if global_hparams: |
| 111 | + hparams.clear() |
| 112 | + hparams.update(hparams_) |
| 113 | + |
| 114 | + if print_hparams and global_print_hparams and global_hparams: |
| 115 | + print('| Hparams chains: ', config_chains) |
| 116 | + print('| Hparams: ') |
| 117 | + for i, (k, v) in enumerate(sorted(hparams_.items())): |
| 118 | + print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "") |
| 119 | + print("") |
| 120 | + global_print_hparams = False |
| 121 | + # print(hparams_.keys()) |
| 122 | + if hparams.get('exp_name') is None: |
| 123 | + hparams['exp_name'] = args.exp_name |
| 124 | + if hparams_.get('exp_name') is None: |
| 125 | + hparams_['exp_name'] = args.exp_name |
| 126 | + return hparams_ |
0 commit comments