@@ -47,9 +47,7 @@ def model(
     ) -> "TaskTransformer":
         if model_data_kwargs is None:
             model_data_kwargs = {}
-        model_data_kwargs = dict(
-            model_data_kwargs
-        )  # avoid ConfigKeyError: Key 'tokenizer' is not in struct`
+        model_data_kwargs = dict(model_data_kwargs)  # avoid ConfigKeyError: Key 'tokenizer' is not in struct`
 
         # use `model_data_kwargs` to pass `tokenizer` and `pipeline_kwargs`
         # as not all models might contain these parameters.
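Note on the hunk above: a minimal sketch of why the `dict(...)` copy matters, assuming `model_data_kwargs` arrives as an OmegaConf `DictConfig` in struct mode (the config keys and values below are made up for illustration). A struct config rejects keys it was not created with, so assigning `tokenizer` later raises the `ConfigKeyError` the comment refers to, while a plain `dict` copy accepts new keys.

```python
from omegaconf import OmegaConf

# Illustrative only: these keys/values are hypothetical, not from the PR.
cfg = OmegaConf.create({"pipeline_kwargs": {}})
OmegaConf.set_struct(cfg, True)  # struct mode forbids unknown keys

try:
    cfg["tokenizer"] = "some-tokenizer"  # rejected while in struct mode
except Exception as err:
    print(type(err).__name__)  # ConfigKeyError (struct mode rejects unknown keys)

model_data_kwargs = dict(cfg)  # shallow copy into a plain dict
model_data_kwargs["tokenizer"] = "some-tokenizer"  # now accepted
```
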
@@ -60,33 +58,25 @@ def model(
 
         return self.instantiate(cfg, instantiator=self, **model_data_kwargs)
 
-    def optimizer(
-        self, model: torch.nn.Module, cfg: DictConfig
-    ) -> torch.optim.Optimizer:
+    def optimizer(self, model: torch.nn.Module, cfg: DictConfig) -> torch.optim.Optimizer:
         no_decay = ["bias", "LayerNorm.weight"]
         grouped_parameters = [
             {
                 "params": [
-                    p
-                    for n, p in model.named_parameters()
-                    if not any(nd in n for nd in no_decay) and p.requires_grad
+                    p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
                 ],
                 "weight_decay": cfg.weight_decay,
             },
             {
                 "params": [
-                    p
-                    for n, p in model.named_parameters()
-                    if any(nd in n for nd in no_decay) and p.requires_grad
+                    p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
                 ],
                 "weight_decay": 0.0,
             },
         ]
         return self.instantiate(cfg, grouped_parameters)
 
-    def scheduler(
-        self, cfg: DictConfig, optimizer: torch.optim.Optimizer
-    ) -> torch.optim.lr_scheduler._LRScheduler:
+    def scheduler(self, cfg: DictConfig, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
         return self.instantiate(cfg, optimizer=optimizer)
 
     def data_module(
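For context, a hedged usage sketch of the weight-decay grouping that `optimizer()` builds above, using a toy module and a plain `torch.optim.AdamW` as stand-ins (neither is part of this PR): parameters whose names match an entry in `no_decay` (biases and LayerNorm weights) go into the `weight_decay=0.0` group, everything else gets the configured decay.

```python
import torch

class TinyBlock(torch.nn.Module):
    # Toy module whose parameter names mirror transformer conventions.
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)

model = TinyBlock()
no_decay = ["bias", "LayerNorm.weight"]
grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay) and p.requires_grad],
        "weight_decay": 0.01,  # stand-in for cfg.weight_decay
    },
    {
        "params": [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay) and p.requires_grad],
        "weight_decay": 0.0,  # no decay for biases / LayerNorm weights
    },
]
optimizer = torch.optim.AdamW(grouped_parameters, lr=5e-5)
print([len(g["params"]) for g in optimizer.param_groups])  # [1, 3]
```
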