@@ -75,7 +75,7 @@ def __init__(
7575 dropout : float ,
7676 learning_rate : float ,
7777 weight_decay : float ,
78- class_weights : torch . Tensor ,
78+ class_weights : list [ float ] ,
7979 # special params
8080 embedding_model_name : str ,
8181 embedding_dim : int ,
@@ -200,20 +200,23 @@ def _val_test_step(self, prefix: str, batch, batch_idx: int) -> torch.Tensor:
200200 )
201201 return loss
202202
203+ @torch .no_grad ()
203204 def validation_step (self , batch , batch_idx ):
204205 return self ._val_test_step (
205206 prefix = "eval" ,
206207 batch = batch ,
207208 batch_idx = batch_idx ,
208209 )
209210
211+ @torch .no_grad ()
210212 def test_step (self , batch , batch_idx ):
211213 return self ._val_test_step (
212214 prefix = "test" ,
213215 batch = batch ,
214216 batch_idx = batch_idx ,
215217 )
216218
219+ @torch .no_grad ()
217220 def predict_step (self , batch : dict [str , Any ], batch_idx : int ) -> Any :
218221 # Get predictions and ground truth tags
219222 predictions = self (sentences = batch ["sentences" ], mask = batch ["mask" ])
@@ -544,7 +547,7 @@ def train(
544547 dropout = parameters .dropout ,
545548 learning_rate = parameters .learning_rate ,
546549 weight_decay = parameters .weight_decay ,
547- class_weights = torch . tensor ( class_weights , dtype = torch . float32 ) ,
550+ class_weights = class_weights ,
548551 id2label = id2label ,
549552 label2id = {v : k for k , v in id2label .items ()},
550553 )
0 commit comments