Skip to content

Commit a9c28a2

Browse files
authored
Explicit seed setting (#1454)
1 parent af69c18 commit a9c28a2

File tree

10 files changed

+31
-25
lines changed

10 files changed

+31
-25
lines changed

docs/packages/utils.rst

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,5 +12,4 @@ General machine learning utilities shared across Snorkel.
1212
filter_labels
1313
preds_to_probs
1414
probs_to_preds
15-
set_seed
1615
to_int_label_array

snorkel/labeling/model/label_model.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import random
23
from collections import Counter
34
from itertools import chain, permutations
45
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union
@@ -13,7 +14,7 @@
1314
from snorkel.labeling.model.graph_utils import get_clique_tree
1415
from snorkel.labeling.model.logger import Logger
1516
from snorkel.types import Config
16-
from snorkel.utils import probs_to_preds, set_seed
17+
from snorkel.utils import probs_to_preds
1718
from snorkel.utils.config_utils import merge_config
1819
from snorkel.utils.lr_schedulers import LRSchedulerConfig
1920
from snorkel.utils.optimizers import OptimizerConfig
@@ -841,7 +842,9 @@ def fit(
841842
TrainConfig(), kwargs # type:ignore
842843
)
843844
# Update base config so that it includes all parameters
844-
set_seed(self.train_config.seed)
845+
random.seed(self.train_config.seed)
846+
np.random.seed(self.train_config.seed)
847+
torch.manual_seed(self.train_config.seed)
845848

846849
L_shift = L_train + 1 # convert to {0, 1, ..., k}
847850
if L_shift.max() > self.cardinality:

snorkel/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,5 @@
44
filter_labels,
55
preds_to_probs,
66
probs_to_preds,
7-
set_seed,
87
to_int_label_array,
98
)

snorkel/utils/core.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,7 @@
11
import hashlib
2-
import random
32
from typing import Dict, List
43

54
import numpy as np
6-
import torch
7-
8-
9-
def set_seed(seed: int) -> None:
10-
"""Set the Python, NumPy, and PyTorch random seeds."""
11-
random.seed(seed)
12-
np.random.seed(seed)
13-
torch.manual_seed(seed)
145

156

167
def _hash(i: int) -> int:

test/classification/test_classifier_convergence.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import random
12
import unittest
23
from typing import List
34

@@ -16,7 +17,6 @@
1617
Task,
1718
Trainer,
1819
)
19-
from snorkel.utils import set_seed
2020

2121
N_TRAIN = 1000
2222
N_VALID = 300
@@ -26,7 +26,9 @@ class ClassifierConvergenceTest(unittest.TestCase):
2626
@classmethod
2727
def setUpClass(cls):
2828
# Ensure deterministic runs
29-
set_seed(123)
29+
random.seed(123)
30+
np.random.seed(123)
31+
torch.manual_seed(123)
3032

3133
@pytest.mark.complex
3234
def test_convergence(self):

test/classification/test_multitask_classifier.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import os
2+
import random
23
import tempfile
34
import unittest
45

@@ -14,7 +15,6 @@
1415
Operation,
1516
Task,
1617
)
17-
from snorkel.utils import set_seed
1818

1919
NUM_EXAMPLES = 10
2020
BATCH_SIZE = 2
@@ -28,7 +28,9 @@ def setUpClass(cls):
2828
cls.dataloader = create_dataloader("task1")
2929

3030
def setUp(self):
31-
set_seed(123)
31+
random.seed(123)
32+
np.random.seed(123)
33+
torch.manual_seed(123)
3234

3335
def test_onetask_model(self):
3436
model = MultitaskClassifier(tasks=[self.task1])

test/classification/training/schedulers/test_schedulers.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,14 @@
1+
import random
12
import unittest
23

4+
import numpy as np
35
import torch
46

57
from snorkel.classification import DictDataLoader, DictDataset
68
from snorkel.classification.training.schedulers import (
79
SequentialScheduler,
810
ShuffledScheduler,
911
)
10-
from snorkel.utils import set_seed
1112

1213
dataset1 = DictDataset(
1314
"d1",
@@ -37,7 +38,9 @@ def test_sequential(self):
3738
self.assertEqual(data, sorted(data))
3839

3940
def test_shuffled(self):
40-
set_seed(123)
41+
random.seed(123)
42+
np.random.seed(123)
43+
torch.manual_seed(123)
4144
scheduler = ShuffledScheduler()
4245
data = []
4346
for (batch, dl) in scheduler.get_batches(dataloaders):

test/labeling/test_convergence.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
1+
import random
12
import unittest
23

34
import numpy as np
45
import pandas as pd
56
import pytest
7+
import torch
68

79
from snorkel.labeling import (
810
LabelingFunction,
@@ -12,7 +14,6 @@
1214
)
1315
from snorkel.preprocess import preprocessor
1416
from snorkel.types import DataPoint
15-
from snorkel.utils import set_seed
1617

1718

1819
def create_data(n: int) -> pd.DataFrame:
@@ -61,7 +62,9 @@ class LabelingConvergenceTest(unittest.TestCase):
6162
@classmethod
6263
def setUpClass(cls):
6364
# Ensure deterministic runs
64-
set_seed(123)
65+
random.seed(123)
66+
np.random.seed(123)
67+
torch.manual_seed(123)
6568

6669
# Create raw data
6770
cls.N_TRAIN = 1500

test/slicing/test_convergence.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import random
12
import unittest
23
from typing import List
34

@@ -23,7 +24,6 @@
2324
slicing_function,
2425
)
2526
from snorkel.types import DataPoint
26-
from snorkel.utils import set_seed
2727

2828

2929
# Define SFs specifying points inside a circle
@@ -55,7 +55,9 @@ class SlicingConvergenceTest(unittest.TestCase):
5555
@classmethod
5656
def setUpClass(cls):
5757
# Ensure deterministic runs
58-
set_seed(123)
58+
random.seed(123)
59+
np.random.seed(123)
60+
torch.manual_seed(123)
5961

6062
# Create raw data
6163
cls.N_TRAIN = 1500

test/slicing/test_slice_combiner.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,18 @@
11
import random
22
import unittest
33

4+
import numpy as np
45
import torch
56

67
from snorkel.slicing import SliceCombinerModule
7-
from snorkel.utils import set_seed
88

99

1010
class SliceCombinerTest(unittest.TestCase):
1111
@classmethod
1212
def setUpClass(cls):
13-
set_seed(123)
13+
random.seed(123)
14+
np.random.seed(123)
15+
torch.manual_seed(123)
1416

1517
def test_forward_shape(self):
1618
"""Test that the reweight representation shape matches expected feature size."""

0 commit comments

Comments
 (0)