From a4ad8903d6c8ab23bcf7ee444863603ee488c94a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 11:07:33 +0100 Subject: [PATCH 1/3] Fix extract_sub_model --- .../data/test_sbs_mha_split_every_piece.onnx | Bin 0 -> 10481 bytes .../test_sbs_mha_split_every_piece.onnx.data | Bin 0 -> 3072 bytes _unittests/ut_helpers/test_onnx_helper.py | 40 ++++++++- _unittests/ut_torch_onnx/test_sbs.py | 76 ++++++++++++++++++ onnx_diagnostic/ext_test_case.py | 7 ++ onnx_diagnostic/helpers/onnx_helper.py | 16 ++-- onnx_diagnostic/torch_onnx/sbs.py | 9 ++- onnx_diagnostic/torch_onnx/sbs_dataclasses.py | 15 +++- 8 files changed, 150 insertions(+), 13 deletions(-) create mode 100644 _unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx create mode 100644 _unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx.data diff --git a/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx b/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx new file mode 100644 index 0000000000000000000000000000000000000000..86d72d011f3eb93ff1cd7cd76b1e82dd1b2f0a1b GIT binary patch literal 10481 zcmdT~U2I!P6^>muai>i;xm!x=-BLH%DseVH{!6m!CLwW}wlQr};*_#ftIM@LNp5q0 zoO`bmr=_Bcv`9Rl6-Y=3Bvb?sXjKH#vLK+bU>{bzAP_=;#HvC_2n0erECL=D!I}H# zxj%NCMyj@Hn(I4r&Y78S&N*|=8HdA}S5xTi018x5cbVN{H`p!AudZpC%yJ^mbh5Hk zWkpUCD*{`M99%4`SnF-S%xRY^Lh=KB3vyxv4KF7$>cm7p>m7q-~e6 z%5T*=m$ytL(zT^29Lye3|I#e^l%%zN9t zEu4lWRxNS@F0mz9W6O$MsuVSrNRl*OmU<5m$UDhA?Md$y>N6D7C%o|KP!Gj+hxx}r z)L^lwi{LTAs%bV3|I%!ldB(d~P2nW9EUP#&>FvxDRmgHGF3p#C?HsRaoK(b1Sl@OB zvb`NS0Py12XoT9-nwlhB2(V()?G~U3H33Q`E52azm%!AV^5IzrHJ`$I?c~=BjbB9 zzoggv-l&^j(nr(Ub(>2@o^M@&QgsEI^7?{wHH{3J^jy+I>DUu?pqa=phwGM?j2x}M zz!8<&<%~)=;YsW9WE$0F*!z{%LovXpE%*2iwEIES+|e*EE?Y-IHna-jtj);6wkyUv`+gpw^L4{P$TTvq&FOeA0OnvC)+!;FF?D3N5 z0apyY7DSEVVNUDFky(qJoy|EB?A=wx#(F#QhJ?XmD8yBHwda4yZer7HA_iJGKZr~~9qx>DWFda4T(We~1HBFc+Uz!w^=e5j^loVE!-cqB zjTsTSb<9d~36CU3^CBk*eSHl&27eI@cZIt^mbYA0co?{#aOJgZg%?U#8L@z3S7f2Y 
zORFO*iY$(jmjSxe73=GxUq{bvOPnY#NKnElD_9#T3tSPe0T+l&j9wieB_*Rs7JJOY z-AqGvC>J?W4`K#1tZJ-UP+4(}W7V?2Yb@Tx$~If(aS@Nml2jcjaT<5?Flv`qR-kgn z1mZW3Lg@-itJ)e9O2m`ta$xXR4n!T<`W_LrcuhpD*A@_6e-uP5`TH2|5fKFtO+dg_BK(8@#PMIJXFs_xeX)OH>YE?={m$CGiD~IZ$Mo_$cc;Gho!RN-*GFsrdL=Xc zmzSQcedr(8rwg%GoEoW<$LfO{C9d4e{a23G?p`{v^FnTWr+hXt)$^$%)1iZ3**Wrs zv$a3G_4%E*-?&wKc#o+K98`9G^v+AQxBfX>`|hXTos4p0wclLdZfS61ENXBg!OzVc zJ)N)J`}(0;X{@8xRuVF=2XEHC_}6Id*PWT^TYoN1Klph2bot@o+TDTAP5$x+uT4hJ z{Pw}`=Ds$yGCfec|K69UT3NXtUDgacAwLd!}8E8u~6 zi4-5r_XuKem0UV-kQl9z(cr)`T)c7yb*)sSA~`i!xlA4^(|z7-Su`33VoDb*23|z| zmL>FQbX0=sGOzKR0IiD31YT55pkoqiOvu@lg05~A=$tH}R4SQPPN5I#T|SJVRD2?- zbfOMP4*^srm@r2n5rN9e+8S2iQbgo86fS9`7pWje4azC|XgmC0z zzEU8MbNnWG%va0gVP~Y|WsXyIra4)H!<)3^H!2w4!sH{=UCHB`q}JQuZi=)m;UW(g zL2xV7e`hFKHg1*FsT)~NE3RFU6^N9~8Jksxf!I*g+X-`losw25A~K^(+%{GmQb?>9>gyna#h zX^g%^d7_*}&+gZy7&hI6pz%>p0Iv)3QHgq|A>U8i8qyfp_%6zjMhxycSP*ups#{{i zZexUH8IX1lQ_?SRy&6%Mo##Y6w+4xtf+aV%&a9K_f1JADR3D0UN!H-*eO@f!63OWnyd&EE9`A4Wcog66F~#jlU^Dnl z%&*hNS)VnWwHXdS?il26r0sI$5m8LxC~KJ(@xlEr(W6o zvYO|mIj*=yE4S5qkrP%L8{5+5ymffEHj@S_#Jxb9l7sjTGGrRI_6Meu7YJT7^okrv5Zo@Jdc zHx~s2t&X}2scU4zAw&6&jxy{`q4wLlzZ;0X98>1dNw&nRMH$L$@I5V+d{c=n>Zb!l z*qX}p-tNb6GoryZ_GhBEp}0!E#sY-LxlOA3u(h^SBKvD8#H*d3Rd_PWu sq&|+9CY6bj8opb4ie8-oStAldx=mBfWb_yzk!jEY)pH9nn9iO516f2*761SM literal 0 HcmV?d00001 diff --git a/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx.data b/_unittests/ut_helpers/data/test_sbs_mha_split_every_piece.onnx.data new file mode 100644 index 0000000000000000000000000000000000000000..853755460e5c6dbb0804db2f1bb09e3c8f8190df GIT binary patch literal 3072 zcmWNTe>~NP_s4I3L{do-imo3agpkBNrz<~`NeCs93Q3AdcQW0~WMyXR)90?4@2!tY z9`ZF)Uq9aGWKxe?S=l~Ts>|-j%*^DYQrc$q_5XRi&g=1bo%1|Lm3$V$xsc_@%-em} z97c>XMcFkKxHinjX*s3UQ_{ zwnx#4VZq1V02HkkL7Z?U%dRJ#!Xmz4hmg`bxis26L`yI)dVaeJIaylNdSZ zK;YIxcyceWDE@GpNP5&~WUV_!zVRbz+)*r?@eq1xDM{0Bwr9BRL&=~TLRG=k@nS&I zxM3X&I=iuZ&R1Zz-~h-vtxzL9A?Z13$g1CUsdv*-JjDHyoRrPj=XR4yi-w_aB%M31 
zkGG4v`X&0Ry)Z7yMX6i5h8mm*-B-spY;+C*Q&MJ>Z_7M=ePPRH3;6s_0^p>MxcA#>8{W zZ#z-6eGv;?{+#l|Uew%*XRFscQMT5M3k$a4@c+1>j$JIu%O6ngrOzPd`)YV0(&g~e zL8$RGKm{scG$07eQU;-*b1%w1T}p;Fp6FN@gALmjqx8;vlsRVM@Rn#inVZNF(*|gG z13B+YOBCr&3r_~zQU1q-M3pcELjG;I?qrLj4cT0p@fbAOktFx=X`jB5$qLJxq`0}0 zRD0b?di{30>dpWu&ZpBk$l%BhZ?3I$Ky{Cp)$$`8I<|o%y>1x#%L+&;aKIk3By@L5 zX8F`2GQO68(#tDQ@va}_e|vJKMJOsWodwC?KZ5)8aT?mYo*&jlW72#F?!20YnuQCn zXrY)7J^i=fzW6CQK6j_Q*EX!F1-`v`C1%N;u(tAZD*kURIcm@1q;Jwt>6pbnda2y| za}}wrMrrc$Zd8stkR*I9=RbAeJ-KfoLKTcg?lCOndb3Z~F`npoP9r9lp!}sL*8loF zb*x!R@tyAA`_hC5eYUX5A`=vevnYea==h&-NM7nG#E4E){p3fmXW0&}a5SUfVk54+ zW`jc4Hq5&k!x!FKa`)SpG~_r4;>m1mF1kWP&ogmk{U+S5?~F})8+lMufMePoSR&el zn(#!9dA1DQ7yc-5Ozk0ik4>z2^%3NChj5WW3O65qNaL%6u<`N^$hRrr$R%miIm4EY z(N+qNp5d3xQK;WrM(U?VJY-u2It>9>Jm(cPj4#3%YbDrkLL9m{0weJe7+dfm9?A-U zyz&ID3HlZ4s~7Tw{`*vS!-=JON@4cGy=cF{79SRbVE5`mRKLq*ySUw0@zWp}yr08T zkvBUnk+4_6Y;J5dq`YD?RH<{IFn9wu_3yw+XJ71YjY58&$jSmOH2mLMP;EIu(sirJ z3yo3KSKRjGd;p4;zY~!gOPLLBvX9AhT&;guARyp7l@(LYoPRU7HC}k1@+)QcD)$EBJWTx_^t_R zHma%To*|kZ^u*HFo7u7U6`96t=lZcaFtNV_cG^dQ``R6~kcMMFm`}}bNa%bkHBEJn(a~+^!rwdQSjX|IO|3UijDcbO7 z97`XqAR|u+h5vR2d^bj->p#6YaF{@r5QAqlUqgoOEDV1?8v}dNF~UEVFMP9w-`=!D z@ij9+de#em7;xez=khr6$UUJn#vAMcd{}ynFfUO z+H6wl$;qpypmE*}A$iwfQht&s6x|JHp?wWc-1Oo`JxlIsIxc8pYb7t7v(a&h1Gzp) z!N9w6(wx*1G$s8oF=Y*PpCh&$DS*LGJ@C(u4Y4-$D3s|QW3$d25SH0va*91vs#as~ ztE*5Ebr*cRmcwM;BWf;Pk2QBRq#WNz;o-kQNoG8zelKOMC|w+Ni^1e$rYxItnfC1d z4k}VFLd)C|C_6PuJrB0PeO+^wzf`wJ`hOu&XPpIU$O%8i&sRyj;tgc}YaEpGizUT# z=X0ZN0c2+DVV`*vUV3l@yPcg`-Chhi+A%nG=?vArz6Gk10IsoK!LF8RC|~S_=jz>f zBs?ECTwlw**DSHOX91S43goemjte>*RbQ>*DJ-Y(lSt%U+UKHmj z{aOC^d~{j14m&r7aYThHw>ZpYnL>-if2{?xGG~r_5e>DyUr~?oW6Jns4LW^mh?23H z{nj3V_#hvUKAq7uu*HchkA>NbjL`9f6Uaxy{LXoJW4*N<3Okh0X#Rtw{z@|${nH$( zu4eGi;1Vn@hRZk&a+mF}*AV3!64_eApnI zjI%+Jk!zde*#W%Jvx^Ok3$XceIx5A{tiF}b$#1O%qmoM!@wN}CG_Dmo&UZuP=r*cR zouXu4E9}_k3w=9-$s~F$n=aYIjq4V(eeQcRz1pd)wFKOM38vamE9iW*hC>r(YByPc za@QtG-=Bc)Q)|KEAMc^gUe{tkGhfUu@gz7KlU34 
z%VM?cA$4?RLkZ^Kp*K0`UY7v!%7v0h{~$PYKAyLS=<)renZHXE3Zk=r`&L<4v#4`l z`+$QDT6{2@Ll3_Ok{^ox;8tUboiD^ zx9_La+;mhl&2aPAi&Pr%AxhqS0Zq}dsJx|5YHkCWy$$WS)k;ob4p>@ei#l79*;w%* z27kAUBMjj8|u literal 0 HcmV?d00001 diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 69aff474..9a9242e8 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -1,6 +1,8 @@ +import os import unittest from typing import Any, Dict, List import numpy as np +import onnx import onnx.helper as oh import onnx.numpy_helper as onh from onnx import TensorProto, FunctionProto, ValueInfoProto @@ -475,7 +477,7 @@ def _mkv_(name): def test_onnx_dtype_name(self): for k in dir(TensorProto): - if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL"}: + if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}: self.assertEqual(k, onnx_dtype_name(getattr(TensorProto, k))) self.assertRaise(lambda: onnx_dtype_name(1000), ValueError) self.assertEqual(onnx_dtype_name(1000, exc=False), "UNEXPECTED") @@ -532,6 +534,42 @@ def _type_rank_fn(name): check_model(new_model) self.check_ort(new_model) + def test_extract_subset_of_nodes_bigger(self): + model = onnx.load( + os.path.join( + os.path.dirname(__file__), "data", "test_sbs_mha_split_every_piece.onnx" + ) + ) + nodes = extract_subset_of_nodes( + model=model, + name="scaled_dot_product_attention", + node_index=16, + cut_points={ + "linear", + "linear_1", + "linear_2", + "output_0", + "scaled_dot_product_attention", + "transpose_2", + "view_2", + "x", + }, + ) + self.assertEqual( + [ + "Mul", + "Reshape", + "Transpose", + "Mul", + "Reshape", + "Transpose", + "FusedMatMul", + "Softmax", + "MatMul", + ], + [n.op_type for n in nodes], + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py index c6381769..450ce6d6 100644 --- a/_unittests/ut_torch_onnx/test_sbs.py +++ 
b/_unittests/ut_torch_onnx/test_sbs.py @@ -1,3 +1,4 @@ +import os import unittest import pandas import onnx @@ -777,6 +778,81 @@ def forward(self, query, key, value, seq_lens): df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx")) # self.clean_dump() + @hide_stdout() + @ignore_warnings((DeprecationWarning, FutureWarning, UserWarning)) + def test_sbs_mha_split_every_piece(self): + torch = self.torch + + class Model(self.torch.nn.Module): + def __init__(self, embed_dim: int, num_heads: int): + super(Model, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + assert embed_dim % num_heads == 0, ( + f"embed_dim % num_heads =! 0 -> " + f"{embed_dim} % {num_heads} = {embed_dim % num_heads}" + ) + + self.W_q = torch.nn.Linear(embed_dim, embed_dim) + self.W_k = torch.nn.Linear(embed_dim, embed_dim) + self.W_v = torch.nn.Linear(embed_dim, embed_dim) + + def split_heads(self, t, seq_len): + return t.view(t.shape[0], seq_len, self.num_heads, self.head_dim).transpose( + 1, 2 + ) + + def forward(self, x): + q = self.split_heads(self.W_q(x), x.shape[1]) + k = self.split_heads(self.W_k(x), x.shape[1]) + v = self.split_heads(self.W_v(x), x.shape[1]) + return ( + torch.nn.functional.scaled_dot_product_attention(q, k, v) + .transpose(1, 2) + .reshape(x.shape[0], x.shape[1], self.embed_dim) + ) + + embed_dim = 16 + num_heads = 4 + seq_len = 10 + batch_size = 2 + inputs = dict(x=torch.randn(batch_size, seq_len, embed_dim)) + model = Model(embed_dim, num_heads) + model(**inputs) + ds = dict(x={0: "batch", 1: "seqlen"}) + + ep = self.torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds) + ) + self.dump_text("test_sbs_mha_split_every_piece.ep", str(ep)) + filename = self.get_dump_file("test_sbs_mha_split_every_piece.onnx") + to_onnx(ep, exporter="custom", filename=filename) + replay = self.get_dump_folder("test_sbs_mha_split_every_piece_replay") + onx = onnx.load(filename) + results = 
list( + run_aligned( + ep, + onx, + kwargs=inputs, + run_cls=OnnxruntimeEvaluator, + verbose=11, + use_tensor=True, + run_onnx_with_torch_inputs=True, + replay_configuration=ReplayConfiguration( + dump_folder=replay, selected_op_types={"MatMul"}, threshold=2**20 + ), + ), + ) + df = pandas.DataFrame(list(results)).dropna(axis=1, how="all") + df.to_excel(self.get_dump_file("test_sbs_mha_split_every_piece.xlsx")) + max_abs = df["err_abs"].max() + self.assertLess(max_abs, 1e-5) + # self.clean_dump() + subonnx = onnx.load(os.path.join(replay, "scaled_dot_product_attention", "model.onnx")) + self.assertEqual(len(subonnx.graph.input), 3) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index 6c19b409..0da11c02 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -845,6 +845,13 @@ def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str: f.write(proto.SerializeToString()) return fullname + def dump_text(self, name: str, text: str, folder: Optional[str] = None) -> str: + """Dumps text in a file.""" + fullname = self.get_dump_file(name, folder=folder) + with open(fullname, "w") as f: + f.write(text) + return fullname + def assertExists(self, name): """Checks the existing of a file.""" if not os.path.exists(name): diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index d12bbd95..734d6a45 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -332,7 +332,7 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str: print(onnx_dtype_name(7)) """ for k in dir(TensorProto): - if k.upper() == k and k != "EXTERNAL": + if k.upper() == k and k not in {"DESCRIPTOR", "EXTERNAL", "DEFAULT"}: v = getattr(TensorProto, k) if v == itype: return k @@ -1219,11 +1219,14 @@ def extract_subset_of_nodes( if name in node.output: node_index = i break - assert ( - 
node_index is not None - and node_index < len(model.graph.node) - and name in model.graph.node[node_index].output - ), f"node_index is still empty or wrong for result {name!r}" + assert node_index is not None and node_index < len(model.graph.node), ( + f"node_index={node_index} (n_nodes={len(model.graph.node)}) " + f"is still empty or wrong for result {name!r}" + ) + assert name in model.graph.node[node_index].output, ( + f"Unable to find {name!r} in {model.graph.node[node_index].output}, " + f"node={pretty_onnx(model.graph.node[node_index])}" + ) if cut_points is None: cut_points = {n.name for n in model.graph.input} | { n.name for n in model.graph.initializer @@ -1236,6 +1239,7 @@ def extract_subset_of_nodes( current_node_index = node_index current_input_index = 0 intermediate = {name} + cut_points -= {name} inputs = set(k for k in node.input if k) while not (inputs <= cut_points) and current_node_index >= 0: node = model.graph.node[current_node_index] diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index 7cac5e79..4bdfed0a 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -381,7 +381,8 @@ def _preparation_with_fx_graph( assert len(torch_input_names) < len(onx.graph.input), ( f"torch_input_names={torch_input_names!r}, " f"onnx_input_names={[n.name for n in onx.graph.input]}, " - f"node.name={node.name!r} cannot be an input" + f"node.name={node.name!r} cannot be an input, " + f"placeholders_to_state_dict={sorted(placeholders_to_state_dict)}" ) assert node.name not in skip_mapping_torch_onnx, ( f"{node.name!r} is ambiguous, cannot be mapped due to " @@ -772,9 +773,9 @@ def forward(self, x): # preparation with ep.graph.nodes ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)} placeholders_to_state_dict = { - **{f"p_{name.replace('.', '_')}": name for name in ep.state_dict}, - **{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()}, - 
**{f"c_{name.replace('.', '_')}": name for name in ep.tensor_constants}, + **{f"p_{name.replace('.', '_').lower()}": name for name in ep.state_dict}, + **{f"b_{name.replace('.', '_').lower()}": name for name, _ in ep.named_buffers()}, + **{f"c_{name.replace('.', '_').lower()}": name for name in ep.tensor_constants}, } skip_mapping_torch_onnx = _duplicated_values(placeholders_to_state_dict) placeholders = {} diff --git a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py index 234e424d..2b7255d2 100644 --- a/onnx_diagnostic/torch_onnx/sbs_dataclasses.py +++ b/onnx_diagnostic/torch_onnx/sbs_dataclasses.py @@ -243,7 +243,17 @@ def dump( :return: the folder created to dump everything """ if verbose: - print(f"[ReplayConfiguration.dump] extract subset of node for {name!r}") + print( + f"[ReplayConfiguration.dump] extract subset of nodes for " + f"{name!r} (onnx_id_node={onnx_id_node})" + ) + if verbose >= 10: + print(f"[ReplayConfiguration.dump] onnx_results={sorted(onnx_results)}") + print(f"[ReplayConfiguration.dump] torch_results={sorted(torch_results)}") + print( + f"[ReplayConfiguration.dump] onnx_name_to_ep_name=" + f"{sorted(onnx_name_to_ep_name)}" + ) nodes = extract_subset_of_nodes( model=model, name=name, @@ -253,7 +263,8 @@ def dump( if not nodes: if verbose: print( - f"[ReplayConfiguration.dump] could not extract subset of node for {name!r}" + f"[ReplayConfiguration.dump] could not extract subset of " + f"nodes for {name!r}" ) return None if verbose: From 3a3a7b0981f90999e9a020e21bdd666882419379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 11:37:15 +0100 Subject: [PATCH 2/3] changelogs --- .github/workflows/ci.yml | 19 ++++++++++--------- CHANGELOGS.rst | 1 + 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c40ab2df..de947320 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -106,15 
+106,16 @@ jobs: pip install torch==${{ matrix.torch }} torchvision torchaudio fi - - name: Cache pip - if: ${{ matrix.torch != 'main' && matrix.transformers != 'main' }} - uses: actions/cache@v4 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }} - restore-keys: | - ${{ runner.os }}-pip- - ${{ runner.os }}- + # too slow + #- name: Cache pip + # if: ${{ matrix.torch != 'main' && matrix.transformers != 'main' }} + # uses: actions/cache@v4 + # with: + # path: ~/.cache/pip + # key: ${{ runner.os }}-pip-${{ hashFiles('requirements-dev.txt') }} + # restore-keys: | + # ${{ runner.os }}-pip- + # ${{ runner.os }}- - name: pip freeze run: python -m pip freeze diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index ee0b1a3e..30fd0e48 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.4 +++++ +* :pr:`337`: fixes extract_subset_of_nodes * :pr:`336`: implements versioned onnx plugs 0.8.3 From 3c9d8675bfb19054b718eb9e6a45f5015f52255d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 3 Dec 2025 11:02:38 +0000 Subject: [PATCH 3/3] handle empty input --- onnx_diagnostic/helpers/onnx_helper.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 734d6a45..85713459 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -1243,13 +1243,22 @@ def extract_subset_of_nodes( inputs = set(k for k in node.input if k) while not (inputs <= cut_points) and current_node_index >= 0: node = model.graph.node[current_node_index] - if current_input_index == 0: + if current_input_index == 0 or not node.input: needs = [o for o in node.output if o in intermediate and o not in cut_points] if needs: selected.add(current_node_index) + if not node.input: + current_node_index -= 1 + current_input_index = 0 + continue else: current_node_index -= 1 + 
current_input_index = 0 continue + assert current_input_index < len(node.input), ( + f"current_input_index={current_input_index} but node.input={node.input}, " + f"node={pretty_onnx(node)}" + ) res = node.input[current_input_index] if res not in cut_points: intermediate.add(res) @@ -1294,8 +1303,8 @@ def _mkv_(name, itype, irank): oh.make_graph( nodes, "submodel", - [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)], - [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)], + [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n], + [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names) if n], ), ir_version=ir_version, opset_imports=opset_imports,