File tree Expand file tree Collapse file tree 1 file changed +12
-11
lines changed Expand file tree Collapse file tree 1 file changed +12
-11
lines changed Original file line number Diff line number Diff line change @@ -1074,11 +1074,19 @@ def disable_lm_head(model: torch.nn.Module):
10741074 """
10751075 _ , lm_head = get_embeddings (model )
10761076 if lm_head is not None :
1077- if not isinstance (lm_head , torch .nn .Linear ):
1078- raise NotImplementedError (
1079- f"Cannot disable LM head of type { lm_head .__class__ .__name__ } "
1080- )
1077+ logger .warning (
1078+ f"Attempted to disable lm_head of instance { model .__class__ .__name__ } , "
1079+ "but was unable to to find lm_head. This may lead to unexpected OOM."
1080+ )
1081+ yield
1082+ return
1083+
1084+ elif not isinstance (lm_head , torch .nn .Linear ):
1085+ logger .warning (f"Cannot disable LM head of type { lm_head .__class__ .__name__ } " )
1086+ yield
1087+ return
10811088
1089+ else :
10821090 dummy_weight = lm_head .weight .to ("meta" )
10831091
10841092 def dummy_forward (self , input : torch .Tensor ) -> torch .Tensor :
@@ -1087,13 +1095,6 @@ def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
10871095 with patch_attr (lm_head , "forward" , dummy_forward .__get__ (lm_head )):
10881096 yield
10891097
1090- else :
1091- logger .warning (
1092- f"Attempted to disable lm_head of instance { model .__class__ .__name__ } , "
1093- "but was unable to to find lm_head. This may lead to unexpected OOM."
1094- )
1095- yield
1096-
10971098
10981099@contextlib .contextmanager
10991100def patch_attr (base : object , attr : str , value : Any ):
You can’t perform that action at this time.
0 commit comments