11# encoding: utf-8
22import pytest
33
4- from wellcomeml .ml import bert_vectorizer
4+ from wellcomeml .ml . bert_vectorizer import BertVectorizer
55
66EMBEDDING_TYPES = [
77 "mean_second_to_last" ,
1212]
1313
1414
15- def test_embed_one_sentence ():
@pytest.fixture(scope='module')
def vec():
    """Return a fitted ``BertVectorizer`` shared by the tests in this module.

    The scope is passed to the ``pytest.fixture`` decorator. pytest ignores
    fixture-function parameters that have default values, so declaring
    ``scope`` as a defaulted parameter of this function would leave the
    fixture function-scoped and rebuild the vectorizer for every test.

    Tests using this fixture assign ``sentence_embedding`` on the shared
    instance, so each test should set the value it needs before transforming.
    """
    vectorizer = BertVectorizer()

    # fit() takes no data here; it is called once per module.
    vectorizer.fit()
    return vectorizer
21+
22+
@pytest.mark.bert
def test_fit_transform_works(vec):
    """Embed one sentence with ``fit_transform`` and check the shape."""
    sentences = ["This is a sentence"]

    embedded = vec.fit_transform(sentences)

    assert embedded.shape == (1, 768)
2228
2329
@pytest.mark.bert
def test_embed_two_sentences(vec):
    """Embed two sentences with each embedding type and check the shape."""
    sentences = ["This is a sentence", "This is another one"]

    for embedding_type in EMBEDDING_TYPES:
        # Reuse the fitted vectorizer; only the embedding type changes.
        vec.sentence_embedding = embedding_type
        embedded = vec.transform(sentences, verbose=False)
        assert embedded.shape == (2, 768)
3441
3542
@pytest.mark.bert
def test_embed_long_sentence(vec):
    """Embed one very long input with each embedding type and check shape."""
    # The sentence is repeated 500 times with no separator between copies.
    long_input = ["This is a sentence" * 500]

    for embedding_type in EMBEDDING_TYPES:
        vec.sentence_embedding = embedding_type
        embedded = vec.transform(long_input, verbose=False)
        assert embedded.shape == (1, 768)
4351
4452
@pytest.mark.bert
def test_embed_scibert():
    """Embed one sentence with the 'scibert' model for each embedding type."""
    # This test builds and fits its own vectorizer with pretrained='scibert'.
    scibert_vec = BertVectorizer(pretrained='scibert')
    scibert_vec.fit()

    sentences = ["This is a sentence"]

    for embedding_type in EMBEDDING_TYPES:
        scibert_vec.sentence_embedding = embedding_type
        embedded = scibert_vec.transform(sentences, verbose=False)
        assert embedded.shape == (1, 768)
5263
5364
5465@pytest .mark .skip ("Reason: Build killed or stalls. Issue #200" )
@@ -58,11 +69,11 @@ def test_save_and_load(tmpdir):
5869 X = ["This is a sentence" ]
5970 for pretrained in ['bert' , 'scibert' ]:
6071 for embedding in EMBEDDING_TYPES :
61- vec = bert_vectorizer . BertVectorizer (
72+ vec = BertVectorizer (
6273 pretrained = pretrained ,
6374 sentence_embedding = embedding
6475 )
65- X_embed = vec .fit_transform (X )
76+ X_embed = vec .fit_transform (X , verbose = False )
6677
6778 vec .save_transformed (str (tmpfile ), X_embed )
6879
0 commit comments