Skip to content

Commit 9c9d52f

Browse files
committed
manual seed
1 parent a4ad5da commit 9c9d52f

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def get_untrained_model_with_inputs(
228228
f"and use_pretrained=True."
229229
)
230230

231+
seed = int(os.environ.get("SEED", "17"))
232+
torch.manual_seed(seed)
231233
try:
232234
if type(config) is dict:
233235
model = cls_model(**config)
@@ -239,6 +241,8 @@ def get_untrained_model_with_inputs(
239241
) from e
240242

241243
# input kwargs
244+
seed = int(os.environ.get("SEED", "17")) + 1
245+
torch.manual_seed(seed)
242246
kwargs, fct = random_input_kwargs(config, task) # type: ignore[arg-type]
243247
if verbose:
244248
print(f"[get_untrained_model_with_inputs] use fct={fct}")

0 commit comments

Comments
 (0)