Skip to content

Commit 373bcd3

Browse files
committed
Refactors sbs to save memory
1 parent 2db1ba1 commit 373bcd3

File tree

3 files changed

+180
-110
lines changed

3 files changed

+180
-110
lines changed

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def forward(self, x):
6969
verbose=10,
7070
),
7171
)
72-
self.assertEqual(len(results), 7)
72+
self.assertEqual(len(results), 6)
7373

7474
@hide_stdout()
7575
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -104,7 +104,7 @@ def forward(self, x):
104104
verbose=10,
105105
),
106106
)
107-
self.assertEqual(len(results), 6)
107+
self.assertEqual(len(results), 5)
108108

109109
@hide_stdout()
110110
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -136,7 +136,7 @@ def forward(self, x):
136136
verbose=10,
137137
),
138138
)
139-
self.assertEqual(len(results), 6)
139+
self.assertEqual(len(results), 5)
140140

141141
@hide_stdout()
142142
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -167,7 +167,7 @@ def forward(self, x):
167167
verbose=11,
168168
),
169169
)
170-
self.assertEqual(len(results), 7)
170+
self.assertEqual(len(results), 6)
171171

172172
@hide_stdout()
173173
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -199,7 +199,7 @@ def forward(self, x):
199199
use_tensor=True,
200200
),
201201
)
202-
self.assertEqual(len(results), 8)
202+
self.assertEqual(len(results), 7)
203203

204204
@hide_stdout()
205205
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -232,7 +232,7 @@ def forward(self, x):
232232
use_tensor=True,
233233
),
234234
)
235-
self.assertEqual(len(results), 8)
235+
self.assertEqual(len(results), 7)
236236

237237
@hide_stdout()
238238
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -267,7 +267,7 @@ def forward(self, x):
267267
use_tensor=True,
268268
),
269269
)
270-
self.assertEqual(len(results), 14)
270+
self.assertEqual(len(results), 8)
271271

272272
@hide_stdout()
273273
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
@@ -301,10 +301,10 @@ def forward(self, x):
301301
use_tensor=True,
302302
),
303303
)
304-
self.assertEqual(len(results), 14)
304+
self.assertEqual(len(results), 8)
305305
self.assertEqual(
306+
[None, None, 0, 0, 0, 0, 0, 0],
306307
[r.err_dev for r in results],
307-
[None, None, None, None, None, None, None, None, 0, 0, 0, 0, 0, 0],
308308
)
309309

310310
@hide_stdout()
@@ -364,13 +364,13 @@ def forward(self, x):
364364
],
365365
sorted(df.columns),
366366
)
367-
self.assertEqual(len(results), 12)
367+
self.assertEqual(len(results), 8)
368368
self.assertEqual(
369+
[None, None, None, None, None, 0, 0, 0],
369370
[r.err_dev for r in results],
370-
[None, None, None, None, None, None, None, None, None, 0, 0, 0],
371371
)
372372
self.assertEqual(
373-
[-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
373+
[-10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
374374
df["onnx_id_node"].fillna(-10).tolist(),
375375
)
376376
self.clean_dump()

onnx_diagnostic/helpers/ort_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def run_dlpack(
512512
v = v.to(torch.uint8)
513513
v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True)
514514
else:
515-
v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
515+
v = ORTC.OrtValue.from_dlpack(v.detach().__dlpack__(), False)
516516
input_names.append(k)
517517
values.push_back(v)
518518
if self.nvtx:

0 commit comments

Comments
 (0)