Skip to content
This repository was archived by the owner on Aug 9, 2023. It is now read-only.

Commit b5ef951

Browse files
author
Campbells
authored
Merge pull request #273 from wellcometrust/feature/speed-up-tests
Feature/speed up tests
2 parents c7619dc + 7ee1690 commit b5ef951

23 files changed

+99
-74
lines changed

.travis.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ install:
1111
- pip install tox-travis
1212

1313
env:
14+
jobs:
15+
- TEST_SUITE='bert'
16+
- TEST_SUITE='not bert'
17+
1418
global:
1519
- TF_CPP_MIN_LOG_LEVEL=2
1620

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
addopts = --strict-markers
33
markers =
44
integration: integration tests
5+
bert: tests that use bert (usually heavy tests)

tests/test_bert_classifier.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
# encoding: utf-8
2+
import pytest
23
import tempfile
34

45
import numpy as np
56

67
from wellcomeml.ml.bert_classifier import BertClassifier
78

89

9-
def test_multilabel():
10+
@pytest.fixture
11+
def multilabel_bert(scope='module'):
12+
model = BertClassifier()
13+
model._init_model(num_labels=4)
14+
15+
return model
16+
17+
18+
@pytest.mark.bert
19+
def test_multilabel(multilabel_bert):
1020
X = [
1121
"One and two",
1222
"One only",
@@ -22,7 +32,7 @@ def test_multilabel():
2232
[0, 1, 1, 0]
2333
])
2434

25-
model = BertClassifier()
35+
model = multilabel_bert
2636
model.fit(X, Y)
2737
Y_pred = model.predict(X)
2838
Y_prob_pred = model.predict_proba(X)
@@ -35,6 +45,7 @@ def test_multilabel():
3545
assert model.losses[0] > model.losses[-1]
3646

3747

48+
@pytest.mark.bert
3849
def test_multiclass():
3950
X = [
4051
"One oh yes",
@@ -64,6 +75,7 @@ def test_multiclass():
6475
assert model.losses[0] > model.losses[-1]
6576

6677

78+
@pytest.mark.bert
6779
def test_scibert():
6880
X = [
6981
"One and two",
@@ -93,7 +105,8 @@ def test_scibert():
93105
assert model.losses[0] > model.losses[-1]
94106

95107

96-
def test_save_load():
108+
@pytest.mark.bert
109+
def test_save_load(multilabel_bert):
97110
X = [
98111
"One and two",
99112
"One only",
@@ -109,7 +122,8 @@ def test_save_load():
109122
[0, 1, 1, 0]
110123
])
111124

112-
model = BertClassifier()
125+
model = multilabel_bert
126+
model.epochs = 1 # Only need to fit 1 epoch here really, because we're testing save
113127
model.fit(X, Y)
114128

115129
with tempfile.TemporaryDirectory() as tmp_path:
@@ -119,10 +133,5 @@ def test_save_load():
119133

120134
Y_pred = loaded_model.predict(X)
121135
Y_prob_pred = loaded_model.predict_proba(X)
122-
assert Y_pred.sum() != 0
123-
assert Y_pred.sum() != Y.size
124-
assert Y_prob_pred.max() <= 1
125-
assert Y_prob_pred.min() >= 0
136+
assert Y_prob_pred.sum() >= 0
126137
assert Y_pred.shape == Y.shape
127-
assert Y_prob_pred.shape == Y.shape
128-
assert model.losses[0] > model.losses[-1]

tests/test_bert_vectorizer.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# encoding: utf-8
22
import pytest
33

4-
from wellcomeml.ml import bert_vectorizer
4+
from wellcomeml.ml.bert_vectorizer import BertVectorizer
55

66
EMBEDDING_TYPES = [
77
"mean_second_to_last",
@@ -12,43 +12,54 @@
1212
]
1313

1414

15-
def test_embed_one_sentence():
15+
@pytest.fixture
16+
def vec(scope='module'):
17+
vectorizer = BertVectorizer()
18+
19+
vectorizer.fit()
20+
return vectorizer
21+
22+
23+
@pytest.mark.bert
24+
def test_fit_transform_works(vec):
1625
X = ["This is a sentence"]
1726

18-
for embedding in EMBEDDING_TYPES:
19-
vec = bert_vectorizer.BertVectorizer(sentence_embedding=embedding)
20-
X_embed = vec.fit_transform(X)
21-
assert(X_embed.shape == (1, 768))
27+
assert vec.fit_transform(X).shape == (1, 768)
2228

2329

24-
def test_embed_two_sentences():
30+
@pytest.mark.bert
31+
def test_embed_two_sentences(vec):
2532
X = [
2633
"This is a sentence",
2734
"This is another one"
2835
]
2936

3037
for embedding in EMBEDDING_TYPES:
31-
vec = bert_vectorizer.BertVectorizer(sentence_embedding=embedding)
32-
X_embed = vec.fit_transform(X)
33-
assert(X_embed.shape == (2, 768))
38+
vec.sentence_embedding = embedding
39+
X_embed = vec.transform(X, verbose=False)
40+
assert X_embed.shape == (2, 768)
3441

3542

36-
def test_embed_long_sentence():
43+
@pytest.mark.bert
44+
def test_embed_long_sentence(vec):
3745
X = ["This is a sentence"*500]
3846

3947
for embedding in EMBEDDING_TYPES:
40-
vec = bert_vectorizer.BertVectorizer(sentence_embedding=embedding)
41-
X_embed = vec.fit_transform(X)
42-
assert(X_embed.shape == (1, 768))
48+
vec.sentence_embedding = embedding
49+
X_embed = vec.transform(X, verbose=False)
50+
assert X_embed.shape == (1, 768)
4351

4452

53+
@pytest.mark.bert
4554
def test_embed_scibert():
4655
X = ["This is a sentence"]
56+
vec = BertVectorizer(pretrained='scibert')
57+
vec.fit()
58+
4759
for embedding in EMBEDDING_TYPES:
48-
vec = bert_vectorizer.BertVectorizer(pretrained='scibert',
49-
sentence_embedding=embedding)
50-
X_embed = vec.fit_transform(X)
51-
assert(X_embed.shape == (1, 768))
60+
vec.sentence_embedding = embedding
61+
X_embed = vec.transform(X, verbose=False)
62+
assert X_embed.shape == (1, 768)
5263

5364

5465
@pytest.mark.skip("Reason: Build killed or stalls. Issue #200")
@@ -58,11 +69,11 @@ def test_save_and_load(tmpdir):
5869
X = ["This is a sentence"]
5970
for pretrained in ['bert', 'scibert']:
6071
for embedding in EMBEDDING_TYPES:
61-
vec = bert_vectorizer.BertVectorizer(
72+
vec = BertVectorizer(
6273
pretrained=pretrained,
6374
sentence_embedding=embedding
6475
)
65-
X_embed = vec.fit_transform(X)
76+
X_embed = vec.fit_transform(X, verbose=False)
6677

6778
vec.save_transformed(str(tmpfile), X_embed)
6879

tests/test_bilstm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tempfile
22

33
from wellcomeml.ml.bilstm import BiLSTMClassifier
4-
from wellcomeml.ml import KerasVectorizer
4+
from wellcomeml.ml.keras_vectorizer import KerasVectorizer
55
from sklearn.pipeline import Pipeline
66
from scipy.sparse import csr_matrix
77
import numpy as np

tests/test_clustering.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from wellcomeml.ml import TextClustering
3+
from wellcomeml.ml.clustering import TextClustering
44

55

66
@pytest.mark.parametrize("reducer,cluster_reduced", [("tsne", True),
@@ -35,11 +35,11 @@ def test_parameter_search(reducer):
3535
'Francis Harry Crick']
3636

3737
param_grid = {
38-
'reducer': {'min_dist': [0.0, 0.2],
39-
'n_neighbors': [2, 3, 5],
38+
'reducer': {'min_dist': [0.0],
39+
'n_neighbors': [2],
4040
'metric': ['cosine', 'euclidean']},
41-
'clustering': {'min_samples': [2, 5],
42-
'eps': [0.5, 1, 1.5]}
41+
'clustering': {'min_samples': [2],
42+
'eps': [0.5]}
4343
}
4444

4545
best_params = cluster.optimise(X, param_grid=param_grid,

tests/test_cnn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tempfile
22

3-
from wellcomeml.ml import CNNClassifier, KerasVectorizer
3+
from wellcomeml.ml.cnn import CNNClassifier
4+
from wellcomeml.ml.keras_vectorizer import KerasVectorizer
45
from sklearn.pipeline import Pipeline
56
from scipy.sparse import csr_matrix
67
import tensorflow as tf

tests/test_doc2vec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from wellcomeml.ml import Doc2VecVectorizer
1+
from wellcomeml.ml.doc2vec_vectorizer import Doc2VecVectorizer
22

33

44
def test_fit_transform():

tests/test_entity_linking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from wellcomeml.ml import SimilarityEntityLinker
2+
from wellcomeml.ml.similarity_entity_linking import SimilarityEntityLinker
33

44

55
@pytest.fixture(scope="module")

tests/test_frequency_vectorizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# encoding: utf-8
2-
from wellcomeml.ml import WellcomeTfidf
2+
from wellcomeml.ml.frequency_vectorizer import WellcomeTfidf
33

44

55
def test_tf_idf_dispatch():

0 commit comments

Comments
 (0)