Commit 345b783

improves documentation, introduces mixture of experts as a task, adds Python 3.10 (#71)

* improves documentation
* first step for moe
* doc
* disable
* add 3.10
* fix ut
* exclude
* ci
* update ci
* add feature extraction

1 parent 94bae15 · commit 345b783

File tree: 18 files changed (+329 −14 lines)

.github/workflows/ci.yml
Lines changed: 13 additions & 3 deletions

@@ -15,10 +15,20 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        python: ['3.11', '3.12']
+        python: ['3.10', '3.11', '3.12']
         transformers: ['4.48.3', '4.51.3', 'main']
-        torch: ['2.6', 'main']
-
+        torch: ['2.6', '2.7', 'main']
+        exclude:
+          - python: '3.10'
+            transformers: 'main'
+          - python: '3.10'
+            torch: '2.7'
+          - python: '3.11'
+            transformers: '4.51.3'
+          - python: '3.11'
+            torch: '2.7'
+          - python: '3.12'
+            torch: '2.6'
     steps:
       - uses: actions/checkout@v3

_doc/api/tasks/feature_extraction.rst (new file)
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+
+onnx_diagnostic.tasks.feature_extraction
+========================================
+
+.. automodule:: onnx_diagnostic.tasks.feature_extraction
+    :members:
+    :no-undoc-members:

_doc/api/tasks/index.rst
Lines changed: 2 additions & 0 deletions

@@ -34,8 +34,10 @@ Or:

     automatic_speech_recognition
     fill_mask
+    feature_extraction
     image_classification
     image_text_to_text
+    mixture_of_expert
     sentence_similarity
     text_classification
     text_generation
_doc/api/tasks/mixture_of_expert.rst (new file)
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+
+onnx_diagnostic.tasks.mixture_of_expert
+=======================================
+
+.. automodule:: onnx_diagnostic.tasks.mixture_of_expert
+    :members:
+    :no-undoc-members:

_doc/conf.py
Lines changed: 8 additions & 2 deletions

@@ -12,6 +12,7 @@
     "sphinx.ext.githubpages",
     "sphinx.ext.ifconfig",
     "sphinx.ext.intersphinx",
+    "sphinx.ext.linkcode",
     "sphinx.ext.mathjax",
     "sphinx.ext.viewcode",
     "sphinx.ext.todo",
@@ -63,15 +64,20 @@
 # ]

 # The following is used by sphinx.ext.linkcode to provide links to github
-linkcode_resolve = make_linkcode_resolve(
-    "onnx-diagnostic",
+_linkcode_resolve = make_linkcode_resolve(
+    "onnx_diagnostic",
     (
         "https://github.com/sdpython/onnx-diagnostic/"
         "blob/{revision}/{package}/"
         "{path}#L{lineno}"
     ),
 )

+
+def linkcode_resolve(domain, info):
+    return _linkcode_resolve(domain, info)
+
+
 latex_elements = {
     "papersize": "a4",
     "pointsize": "10pt",

_doc/recipes/plot_dynamic_shapes_max.py
Lines changed: 3 additions & 1 deletion

@@ -10,6 +10,8 @@
 in the exported program is something very aggreessive. Here is a case where
 it takes a wrong decision and how to get around it.

+**This bug was fixed after 4/24/2025**.
+
 Wrong Model
 +++++++++++
 """
@@ -183,4 +185,4 @@ def forward(self, x, y, fact):
 # is hidden in a custom operator.


-doc.plot_legend("dynamic shapes\nworkaround\nmax(d1, d2)", "dynamic shapes", "yellow")
+doc.plot_legend("max(d1, d2)\nwith d1, d2 dimensions", "dynamic shapes", "green")

_unittests/ut_tasks/test_tasks.py
Lines changed: 12 additions & 0 deletions

@@ -116,6 +116,18 @@ def test_fill_mask(self):
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
         )

+    @hide_stdout()
+    def test_feature_extraction(self):
+        mid = "facebook/bart-base"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        with bypass_export_some_errors(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
     @hide_stdout()
     def test_text_classification(self):
         mid = "Intel/bert-base-uncased-mrpc"

_unittests/ut_tasks/try_tasks.py
Lines changed: 119 additions & 2 deletions

@@ -99,8 +99,8 @@ def test_text2text_generation(self):
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

     @never_test()
-    def test_text_generation_phi4(self):
-        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4
+    def test_text_generation_phi4_mini(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_mini

         import torch
         from transformers import RobertaTokenizer, T5ForConditionalGeneration
@@ -124,6 +124,107 @@ def test_text_generation_phi4(self):
         )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

+    @never_test()
+    @unittest.skip(
+        reason="AttributeError: 'Phi4MMModel' object has no attribute "
+        "'prepare_inputs_for_generation'"
+    )
+    def test_text_generation_phi4_moe(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_moe
+
+        import requests
+        import io
+        from PIL import Image
+        import soundfile as sf
+        from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+        from urllib.request import urlopen
+
+        # Define model path
+        model_path = "microsoft/Phi-4-multimodal-instruct"
+
+        # Load model and processor
+        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            device_map="cuda",
+            torch_dtype="auto",
+            trust_remote_code=True,
+            # if you do not use Ampere or later GPUs, change attention to "eager"
+            # _attn_implementation='flash_attention_2',
+            _attn_implementation="eager",
+        ).cuda()
+
+        # Load generation config
+        generation_config = GenerationConfig.from_pretrained(model_path)
+
+        # Define prompt structure
+        user_prompt = "<|user|>"
+        assistant_prompt = "<|assistant|>"
+        prompt_suffix = "<|end|>"
+
+        # Part 1: Image Processing
+        print("\n--- IMAGE PROCESSING ---")
+        image_url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        prompt = (
+            f"{user_prompt}<|image_1|>What is shown in this image"
+            f"?{prompt_suffix}{assistant_prompt}"
+        )
+        print(f">>> Prompt\n{prompt}")
+
+        # Download and open image
+        image = Image.open(requests.get(image_url, stream=True).raw)
+        inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0")
+
+        # Generate response
+        print("--------- IMAGE PROCESSING ----------")
+        print()
+        with steal_forward(model):
+            generate_ids = model.generate(
+                **inputs,
+                max_new_tokens=1000,
+                generation_config=generation_config,
+            )
+        generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
+        response = processor.batch_decode(
+            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+        print(f">>> Response\n{response}")
+
+        # Part 2: Audio Processing
+        print("\n--- AUDIO PROCESSING ---")
+        audio_url = (
+            "https://upload.wikimedia.org/wikipedia/commons/b/b0/"
+            "Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
+        )
+        speech_prompt = (
+            "Transcribe the audio to text, and then translate the audio to French. "
+            "Use <sep> as a separator between the original transcript and the translation."
+        )
+        prompt = f"{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}"
+        print(f">>> Prompt\n{prompt}")
+
+        # Download and open audio file
+        audio, samplerate = sf.read(io.BytesIO(urlopen(audio_url).read()))
+
+        # Process with the model
+        inputs = processor(text=prompt, audios=[(audio, samplerate)], return_tensors="pt").to(
+            "cuda:0"
+        )
+
+        print("--------- AUDIO PROCESSING ----------")
+        print()
+        with steal_forward(model):
+            generate_ids = model.generate(
+                **inputs,
+                max_new_tokens=1000,
+                generation_config=generation_config,
+            )
+        generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
+        response = processor.batch_decode(
+            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+        print(f">>> Response\n{response}")
+
     @never_test()
     def test_imagetext2text_generation(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k etext2t
@@ -237,6 +338,22 @@ def test_fill_mask(self):
         output = model(**encoded_input)
         print("-- outputs", string_type(output, with_shape=True, with_min_max=True))

+    @never_test()
+    def test_feature_extraction(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k feature_ex
+        # https://huggingface.co/google-bert/bert-base-multilingual-cased
+
+        from transformers import BartTokenizer, BartModel
+
+        tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+        model = BartModel.from_pretrained("facebook/bart-base")
+        text = "Replace me by any text you'd like."
+        encoded_input = tokenizer(text, return_tensors="pt")
+        print()
+        print("-- inputs", string_type(encoded_input, with_shape=True, with_min_max=True))
+        output = model(**encoded_input)
+        print("-- outputs", string_type(output, with_shape=True, with_min_max=True))
+
     @never_test()
     def test_text_classification(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k text_cl
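For context on the new feature-extraction snippet: BartModel returns a Seq2SeqModelOutput whose last_hidden_state carries one vector per token. A short sketch of turning that into a single sentence embedding (the mean-pooling step is our illustrative choice, not something the snippet above does):

import torch
from transformers import BartTokenizer, BartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base")
enc = tokenizer("Replace me by any text you'd like.", return_tensors="pt")
with torch.no_grad():
    out = model(**enc)
# last_hidden_state has shape (batch, sequence, hidden); mean-pool the
# token axis to get one vector per sentence.
embedding = out.last_hidden_state.mean(dim=1)
print(embedding.shape)  # torch.Size([1, 768]) for bart-base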

_unittests/ut_torch_export_patches/test_patch_serialization.py
Lines changed: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
         self.assertEqualAny([cache], cache2)

     @ignore_warnings(UserWarning)
-    @requires_torch("2.7")
+    @requires_torch("2.8")
     def test_sliding_window_cache_export(self):
         class Model(torch.nn.Module):
             def forward(self, cache):

_unittests/ut_torch_models/test_test_helpers.py
Lines changed: 6 additions & 1 deletion

@@ -22,7 +22,10 @@ class TestTestHelper(ExtTestCase):
     def test_get_inputs_for_task(self):
         fcts = supported_tasks()
         for task in self.subloop(sorted(fcts)):
-            data = get_inputs_for_task(task)
+            try:
+                data = get_inputs_for_task(task)
+            except NotImplementedError:
+                continue
             self.assertIsInstance(data, dict)
             self.assertIn("inputs", data)
             self.assertIn("dynamic_shapes", data)
@@ -99,9 +102,11 @@ def test_validate_model_custom(self):
             patch=True,
             stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
             optimization="default",
+            quiet=False,
         )
         self.assertIsInstance(summary, dict)
         self.assertIsInstance(data, dict)
+        self.assertIn("disc_onnx_ort_run_abs", summary)
         self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
         onnx_filename = data["onnx_filename"]
         output_path = f"{onnx_filename}.ortopt.onnx"
