Skip to content

Commit 6b9dcf7

Browse files
committed
complete phi4-mm vision conversion
1 parent fd1d97f commit 6b9dcf7

File tree

6 files changed

+518
-280
lines changed

6 files changed

+518
-280
lines changed

.github/workflows/models448.yml

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
name: MODELS - 4.48.3
2+
3+
on:
4+
push:
5+
pull_request:
6+
types:
7+
- closed
8+
branches:
9+
- main
10+
11+
jobs:
12+
run:
13+
name: to-${{ matrix.torch }}-tr-${{ matrix.transformers }}-ci ${{ matrix.os }}-${{ matrix.python }}
14+
runs-on: ${{ matrix.os }}
15+
strategy:
16+
fail-fast: false
17+
matrix:
18+
os: [ubuntu-latest]
19+
python: ['3.12']
20+
transformers: ['4.48.3']
21+
torch: ['main']
22+
steps:
23+
- uses: actions/checkout@v3
24+
25+
- uses: actions/setup-python@v4
26+
with:
27+
python-version: ${{ matrix.python }}
28+
29+
- name: Install pytorch ${{ matrix.torch }}
30+
run: |
31+
if [[ "${{ matrix.torch }}" == "main" ]]; then
32+
python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
33+
else
34+
echo "install torch==${{ matrix.torch }} torchvision torchaudio"
35+
pip install torch==${{ matrix.torch }} torchvision torchaudio
36+
fi
37+
38+
- name: Install transformers ${{ matrix.transformers }}
39+
run: |
40+
if [[ "${{ matrix.transformers }}" == "main" ]]; then
41+
echo "install transformers from github"
42+
git clone https://github.com/huggingface/transformers.git
43+
cd transformers
44+
pip install -e .
45+
cd ..
46+
else
47+
echo "install transformers==${{ matrix.transformers }}"
48+
pip install transformers==${{ matrix.transformers }}
49+
fi
50+
51+
- name: Install requirements
52+
run: python -m pip install -r requirements.txt
53+
54+
- name: Install requirements dev
55+
run: python -m pip install -r requirements-dev.txt
56+
57+
- name: Uninstall onnx-diagnostic
58+
run: python -m pip uninstall -y onnx-diagnostic
59+
60+
- name: pip freeze
61+
run: python -m pip freeze
62+
63+
- name: Phi-4-multimodal-instruct - vision
64+
run: |
65+
PYTHONPATH=. python -m onnx_diagnostic.ci_models.export_phi4_mm -m microsoft/Phi-4-multimodal-instruct --device cpu --dtype float16 --exporter custom --no-pretrained --no-second-input --atol 2 --part vision

_unittests/ut_torch_export_patches/test_patch_loops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,18 @@ def scan_filter_position_ids(
6565
):
6666

6767
def body(p_attn_mask, position_ids_row):
68-
h_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[:, 0].sum()
69-
w_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[0].sum()
68+
h_len = torch.tensor(1, dtype=boundaries.dtype) / p_attn_mask[:, 0].sum()
69+
w_len = torch.tensor(1, dtype=boundaries.dtype) / p_attn_mask[0].sum()
7070
torch._check(h_len.item() > 0)
7171
fractional_coords_h = torch.arange(
72-
torch.tensor(0.0, dtype=p_attn_mask.dtype),
73-
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
72+
torch.tensor(0.0, dtype=boundaries.dtype),
73+
torch.tensor(1 - 1e-6, dtype=boundaries.dtype),
7474
h_len,
7575
)
7676
torch._check(w_len.item() > 0)
7777
fractional_coords_w = torch.arange(
78-
torch.tensor(0.0, dtype=p_attn_mask.dtype),
79-
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
78+
torch.tensor(0.0, dtype=boundaries.dtype),
79+
torch.tensor(1 - 1e-6, dtype=boundaries.dtype),
8080
w_len,
8181
)
8282

onnx_diagnostic/ci_models/ci_helpers.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import time
44
import subprocess
5-
from argparse import ArgumentParser, BooleanOptionalAction
5+
from argparse import ArgumentParser, BooleanOptionalAction, RawTextHelpFormatter
66
from typing import Any, Dict, List, Tuple
77
import onnx
88

@@ -50,10 +50,13 @@ def get_torch_dtype_from_command_line_args(dtype: str) -> "torch.dtype": # noqa
5050
return torch_dtype[dtype]
5151

5252

53-
def get_parser(name: str) -> ArgumentParser:
53+
def get_parser(name: str, epilog: str = "") -> ArgumentParser:
5454
"""Creates a default parser for many models."""
5555
parser = ArgumentParser(
56-
prog=name, description=f"""Export command line for model {name!r}."""
56+
prog=name,
57+
description=f"""Export command line for model {name!r}.""",
58+
epilog=epilog,
59+
formatter_class=RawTextHelpFormatter,
5760
)
5861
parser.add_argument(
5962
"-m",
@@ -110,7 +113,7 @@ def get_parser(name: str) -> ArgumentParser:
110113
"-a",
111114
"--atol",
112115
type=float,
113-
default=1.0,
116+
default=2.0,
114117
help="fails if the maximum discrepancy is above that threshold",
115118
)
116119
parser.add_argument(
@@ -311,7 +314,8 @@ def fprint(s):
311314
diff = max_diff(flat_export_expected, small, hist=[0.1, 0.01])
312315
fprint(f"-- discrepancies={diff}")
313316
assert diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01, (
314-
f"absolution tolerance is above {atol} or number of mismatches is above "
317+
f"absolution tolerance {diff['abs']} is above {atol} or number of "
318+
f"mismatches ({diff['rep']['>0.1'] / diff['n']}) is above "
315319
f"{mismatch01}, dicrepancies={string_diff(diff)}"
316320
)
317321

@@ -362,8 +366,9 @@ def fprint(s):
362366
assert (
363367
diff["abs"] <= atol and diff["rep"][">0.1"] / diff["n"] <= mismatch01
364368
), (
365-
f"absolution tolerance is above {atol} or number of mismatches is "
366-
f"above {mismatch01}, dicrepancies={string_diff(diff)}"
369+
f"absolution tolerance {diff['abs']} is above {atol} or number "
370+
f" of mismatches ({diff['rep']['>0.1'] / diff['n']}) "
371+
f"is above {mismatch01}, dicrepancies={string_diff(diff)}"
367372
)
368373
js = string_diff(diff, js=True, ratio=True, inputs=se, **info)
369374
fs.write(js)

0 commit comments

Comments
 (0)