Skip to content

Commit 31bee79

Browse files
committed
Make from_pretrained strict by default
1 parent 70776ea commit 31bee79

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

segmentation_models_pytorch/base/hub_mixin.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def config(self) -> dict:
121121

122122

123123
@wraps(PyTorchModelHubMixin.from_pretrained)
124-
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
124+
def from_pretrained(
125+
pretrained_model_name_or_path: str, *args, strict: bool = True, **kwargs
126+
):
125127
config_path = Path(pretrained_model_name_or_path) / "config.json"
126128
if not config_path.exists():
127129
config_path = hf_hub_download(
@@ -137,7 +139,9 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
137139
import segmentation_models_pytorch as smp
138140

139141
model_class = getattr(smp, model_class_name)
140-
return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
142+
return model_class.from_pretrained(
143+
pretrained_model_name_or_path, *args, **kwargs, strict=strict
144+
)
141145

142146

143147
def supports_config_loading(func):

0 commit comments

Comments
 (0)