Skip to content

Commit dc00d07

Browse files
committed
Add support for sentence similarity
1 parent 035ccf8 commit dc00d07

17 files changed

+183
-20
lines changed

_doc/api/tasks/automatic_speech_recognition.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
onnx_diagnostic.export.automatic_speech_recognition
3-
===================================================
2+
onnx_diagnostic.tasks.automatic_speech_recognition
3+
==================================================
44

55
.. automodule:: onnx_diagnostic.tasks.automatic_speech_recognition
66
:members:

_doc/api/tasks/fill_mask.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
onnx_diagnostic.export.fill_mask
3-
================================
2+
onnx_diagnostic.tasks.fill_mask
3+
===============================
44

55
.. automodule:: onnx_diagnostic.tasks.fill_mask
66
:members:

_doc/api/tasks/image_classification.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
onnx_diagnostic.export.image_classification
3-
===========================================
2+
onnx_diagnostic.tasks.image_classification
3+
==========================================
44

55
.. automodule:: onnx_diagnostic.tasks.image_classification
66
:members:

_doc/api/tasks/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ onnx_diagnostic.tasks
99
fill_mask
1010
image_classification
1111
image_text_to_text
12+
sentence_similarity
1213
text_classification
1314
text_generation
1415
text2text_generation
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.tasks.sentence_similarity
3+
=========================================
4+
5+
.. automodule:: onnx_diagnostic.tasks.sentence_similarity
6+
:members:
7+
:no-undoc-members:

_doc/api/tasks/text2text_generation.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
onnx_diagnostic.export.text2text_generation
3-
===========================================
2+
onnx_diagnostic.tasks.text2text_generation
3+
==========================================
44

55
.. automodule:: onnx_diagnostic.tasks.text2text_generation
66
:members:

_doc/api/tasks/text_classification.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
onnx_diagnostic.export.text_classification
3-
==========================================
2+
onnx_diagnostic.tasks.text_classification
3+
=========================================
44

55
.. automodule:: onnx_diagnostic.tasks.text_classification
66
:members:

_doc/api/tasks/text_generation.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
onnx_diagnostic.export.text_generation
3-
======================================
2+
onnx_diagnostic.tasks.text_generation
3+
=====================================
44

55
.. automodule:: onnx_diagnostic.tasks.text_generation
66
:members:

_doc/api/tasks/zero_shot_image_classification.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
onnx_diagnostic.export.zero_shot_image_classification
3-
=====================================================
2+
onnx_diagnostic.tasks.zero_shot_image_classification
3+
====================================================
44

55
.. automodule:: onnx_diagnostic.tasks.zero_shot_image_classification
66
:members:

_unittests/ut_tasks/test_tasks.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ class TestTasks(ExtTestCase):
99
@hide_stdout()
1010
def test_text2text_generation(self):
1111
mid = "sshleifer/tiny-marian-en-de"
12-
# mid = "Salesforce/codet5-small"
1312
data = get_untrained_model_with_inputs(mid, verbose=1)
1413
self.assertIn((data["size"], data["n_weights"]), [(473928, 118482)])
1514
model, inputs = data["model"], data["inputs"]
@@ -85,7 +84,6 @@ def test_automatic_speech_recognition(self):
8584
@hide_stdout()
8685
def test_imagetext2text_generation(self):
8786
mid = "HuggingFaceM4/tiny-random-idefics"
88-
# mid = "Salesforce/codet5-small"
8987
data = get_untrained_model_with_inputs(mid, verbose=1)
9088
self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
9189
model, inputs = data["model"], data["inputs"]
@@ -94,7 +92,6 @@ def test_imagetext2text_generation(self):
9492
@hide_stdout()
9593
def test_fill_mask(self):
9694
mid = "google-bert/bert-base-multilingual-cased"
97-
# mid = "Salesforce/codet5-small"
9895
data = get_untrained_model_with_inputs(mid, verbose=1)
9996
self.assertIn((data["size"], data["n_weights"]), [(428383212, 107095803)])
10097
model, inputs = data["model"], data["inputs"]
@@ -103,12 +100,19 @@ def test_fill_mask(self):
103100
@hide_stdout()
104101
def test_text_classification(self):
105102
mid = "Intel/bert-base-uncased-mrpc"
106-
# mid = "Salesforce/codet5-small"
107103
data = get_untrained_model_with_inputs(mid, verbose=1)
108104
self.assertIn((data["size"], data["n_weights"]), [(154420232, 38605058)])
109105
model, inputs = data["model"], data["inputs"]
110106
model(**inputs)
111107

108+
@hide_stdout()
109+
def test_sentence_similary(self):
110+
mid = "sentence-transformers/all-MiniLM-L6-v1"
111+
data = get_untrained_model_with_inputs(mid, verbose=1)
112+
self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
113+
model, inputs = data["model"], data["inputs"]
114+
model(**inputs)
115+
112116

113117
if __name__ == "__main__":
114118
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)