1+ import logging
2+ from typing import Optional , TYPE_CHECKING , Union
3+
4+ import hydra
5+ import pytorch_lightning as pl
6+ import torch
7+ from omegaconf import DictConfig
8+
9+ from lightning_transformers .core import TransformerDataModule
10+ from lightning_transformers .core .data import TokenizerDataModule
11+
12+ if TYPE_CHECKING :
13+ # avoid circular imports
14+ from lightning_transformers .core import TaskTransformer
15+
16+
class Instantiator:
    """Interface for factories that build training components from configs.

    Every method here is a stub; concrete subclasses must override the ones
    they support.
    """

    @staticmethod
    def _unimplemented():
        # Single raiser so every stub carries the identical message.
        raise NotImplementedError("Child class must implement method")

    def model(self, *args, **kwargs):
        """Build a task model. Must be implemented by subclasses."""
        self._unimplemented()

    def optimizer(self, *args, **kwargs):
        """Build an optimizer. Must be implemented by subclasses."""
        self._unimplemented()

    def scheduler(self, *args, **kwargs):
        """Build an LR scheduler. Must be implemented by subclasses."""
        self._unimplemented()

    def data_module(self, *args, **kwargs):
        """Build a data module. Must be implemented by subclasses."""
        self._unimplemented()

    def logger(self, *args, **kwargs):
        """Build a logger. Must be implemented by subclasses."""
        self._unimplemented()

    def trainer(self, *args, **kwargs):
        """Build a trainer. Must be implemented by subclasses."""
        self._unimplemented()

    def instantiate(self, *args, **kwargs):
        """Generic object instantiation hook. Must be implemented by subclasses."""
        self._unimplemented()
39+
40+
class HydraInstantiator(Instantiator):
    """Instantiator backed by ``hydra.utils.instantiate``.

    Each method forwards a ``DictConfig`` (whose ``_target_`` names the class
    to build) to Hydra, adding the keyword arguments that component needs.
    """

    def model(
        self,
        cfg: DictConfig,
        model_data_kwargs: Optional[DictConfig] = None,
        tokenizer: Optional[DictConfig] = None,
        pipeline_kwargs: Optional[DictConfig] = None,
    ) -> "TaskTransformer":
        """Instantiate a task model from ``cfg``.

        Args:
            cfg: config describing the model target.
            model_data_kwargs: extra kwargs forwarded to the model constructor.
            tokenizer: optional tokenizer config; instantiated and passed as
                ``tokenizer`` when provided.
            pipeline_kwargs: optional kwargs passed through as ``pipeline_kwargs``.
        """
        if model_data_kwargs is None:
            model_data_kwargs = {}
        # Convert to a plain dict so adding keys below does not trip OmegaConf
        # struct mode (avoids ConfigKeyError: Key 'tokenizer' is not in struct).
        model_data_kwargs = dict(model_data_kwargs)

        # Route `tokenizer` and `pipeline_kwargs` through `model_data_kwargs`,
        # as not all models accept these parameters.
        if tokenizer:
            model_data_kwargs["tokenizer"] = self.instantiate(tokenizer)
        if pipeline_kwargs:
            model_data_kwargs["pipeline_kwargs"] = pipeline_kwargs

        return self.instantiate(cfg, instantiator=self, **model_data_kwargs)

    def optimizer(self, model: torch.nn.Module, cfg: DictConfig) -> torch.optim.Optimizer:
        """Instantiate an optimizer over ``model``'s trainable parameters.

        ``cfg.weight_decay`` is applied to every trainable parameter except
        biases and LayerNorm weights, which get a weight decay of 0.0.
        """
        no_decay = ("bias", "LayerNorm.weight")
        # Materialize the trainable parameters once instead of iterating
        # `model.named_parameters()` separately for each group.
        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        grouped_parameters = [
            {
                "params": [p for n, p in trainable if not any(nd in n for nd in no_decay)],
                "weight_decay": cfg.weight_decay,
            },
            {
                "params": [p for n, p in trainable if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return self.instantiate(cfg, grouped_parameters)

    def scheduler(self, cfg: DictConfig, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
        """Instantiate an LR scheduler bound to ``optimizer``."""
        return self.instantiate(cfg, optimizer=optimizer)

    def data_module(
        self,
        cfg: DictConfig,
        tokenizer: Optional[DictConfig] = None,
    ) -> Union[TransformerDataModule, TokenizerDataModule]:
        """Instantiate a data module, instantiating ``tokenizer`` first when given."""
        if tokenizer:
            return self.instantiate(cfg, tokenizer=self.instantiate(tokenizer))
        return self.instantiate(cfg)

    def logger(self, cfg: DictConfig) -> Optional[logging.Logger]:
        """Instantiate the experiment logger when ``cfg.log`` is truthy.

        Returns ``cfg.trainer.logger`` unchanged when it is a plain bool flag,
        an instantiated logger otherwise, and ``None`` when logging is disabled.
        """
        if cfg.get("log"):
            if isinstance(cfg.trainer.logger, bool):
                return cfg.trainer.logger
            return self.instantiate(cfg.trainer.logger)
        return None  # logging disabled

    def trainer(self, cfg: DictConfig, **kwargs) -> pl.Trainer:
        """Instantiate a ``pl.Trainer`` from ``cfg``."""
        return self.instantiate(cfg, **kwargs)

    def instantiate(self, *args, **kwargs):
        """Delegate directly to ``hydra.utils.instantiate``."""
        return hydra.utils.instantiate(*args, **kwargs)
0 commit comments