Skip to content

Commit 4130fb2

Browse files
committed
Move exported model to cpu before save.
1 parent 653d653 commit 4130fb2

File tree

14 files changed

+82
-17
lines changed

14 files changed

+82
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
- change: Install the TensorRT package for architectures other than x86_64
2929
- change: Disable conversion fallback for TensorRT paths and expose control option in custom config
3030
- fix: Correctness command relative tolerance formula
31+
- fix: Memory management during export and conversion process for Torch
3132

3233
## 0.13.0
3334

model_navigator/commands/data_dump/samples.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from model_navigator.core.tensor import TensorMetadata
2626
from model_navigator.core.workspace import Workspace
2727
from model_navigator.frameworks import Framework
28+
from model_navigator.frameworks.memory import offload_model_to_cpu
2829
from model_navigator.runners.base import NavigatorRunner
2930

3031

@@ -165,6 +166,7 @@ class FetchOutputModelData(Command, is_required=True):
165166

166167
def _run(
167168
self,
169+
framework: Framework,
168170
workspace: Workspace,
169171
model: Any,
170172
runner_cls: Type[NavigatorRunner],
@@ -194,6 +196,8 @@ def _run(
194196
output_data_path = workspace.path / "model_output"
195197
output_data_path.mkdir(parents=True, exist_ok=True)
196198

199+
offload_model_to_cpu(model, framework)
200+
197201
runner_kwargs = runner_config.to_dict() if runner_config is not None else {}
198202
runner = runner_cls(
199203
model=model, input_metadata=input_metadata, output_metadata=output_metadata, **runner_kwargs
@@ -211,6 +215,8 @@ def _run(
211215
sample_path = output_data_path / sample_name
212216
samples_to_npz(outputs, sample_path, batch_dim, raise_on_error=raise_on_error, num_samples=len(samples))
213217

218+
runner.deactivate()
219+
214220
return CommandOutput(
215221
status=CommandStatus.OK,
216222
)

model_navigator/commands/export/exporters/torch2dynamo_onnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def export(
7070
device_max_batch_size: Maximum batch size that fits on the device. Defaults to None.
7171
"""
7272
model = get_model()
73+
model.to(target_device)
7374

7475
if not navigator_workspace:
7576
navigator_workspace = pathlib.Path.cwd()

model_navigator/commands/export/exporters/torch2exportedprogram.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from model_navigator.core.dataloader import expand_sample, load_samples
2323
from model_navigator.core.tensor import TensorMetadata
2424
from model_navigator.exceptions import ModelNavigatorRuntimeError
25+
from model_navigator.frameworks.torch.utils import offload_torch_model_to_cpu
2526

2627

2728
def get_model() -> torch.nn.Module:
@@ -133,4 +134,7 @@ def export(
133134
exported_model_path = pathlib.Path(exported_model_path)
134135
if not exported_model_path.is_absolute():
135136
exported_model_path = navigator_workspace / exported_model_path
137+
136138
torch.export.save(exported_model, exported_model_path.as_posix())
139+
140+
offload_torch_model_to_cpu(exported_model.module())

model_navigator/commands/export/exporters/torch2onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def export(
6363
For available arguments check PyTorch documentation: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export
6464
"""
6565
model = get_model()
66-
model = model.to(export_device)
66+
model.to(export_device)
6767

6868
if not navigator_workspace:
6969
navigator_workspace = pathlib.Path.cwd()

model_navigator/commands/export/exporters/torch2quantized_onnx.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from model_navigator.core.dataloader import load_samples
2727
from model_navigator.core.logger import LOGGER
2828
from model_navigator.core.tensor import TensorMetadata
29+
from model_navigator.frameworks.torch.utils import offload_torch_model_to_cpu
2930
from model_navigator.utils.common import numpy_to_torch_dtype
3031

3132

@@ -105,7 +106,7 @@ def export(
105106
model_copy = deepcopy(original_model)
106107

107108
# Offload original model to CPU
108-
original_model.to("cpu")
109+
offload_torch_model_to_cpu(original_model)
109110

110111
try:
111112
# Move model copy to target device
@@ -226,10 +227,8 @@ def forward_loop(model):
226227
LOGGER.info("Quantized ONNX export completed successfully")
227228

228229
# Clean up
229-
del model_copy
230-
del quantized_model
231-
torch.cuda.empty_cache()
232-
230+
offload_torch_model_to_cpu(model_copy)
231+
offload_torch_model_to_cpu(quantized_model)
233232
except Exception as e:
234233
LOGGER.error(f"Error during quantized ONNX export: {str(e)}")
235234
raise

model_navigator/commands/export/exporters/torch2torchscript.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def export(
6161
For available arguments check PyTorch documentation: https://pytorch.org/docs/stable/jit.html#torch.jit.trace
6262
"""
6363
model = get_model()
64+
6465
target_jit_type = JitType(target_jit_type)
6566

6667
if not navigator_workspace:

model_navigator/commands/export/torch.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from model_navigator.core.tensor import TensorMetadata
2828
from model_navigator.core.workspace import Workspace
2929
from model_navigator.exceptions import ModelNavigatorConfigurationError
30+
from model_navigator.frameworks.torch.utils import offload_torch_model_to_cpu
3031
from model_navigator.utils.common import parse_kwargs_to_cmd
3132

3233

@@ -93,7 +94,7 @@ def _run(
9394

9495
# Keep model on CPU after operation
9596
def on_exit():
96-
model.to("cpu")
97+
offload_torch_model_to_cpu(model)
9798

9899
with ExecutionContext(
99100
workspace=workspace,
@@ -197,7 +198,7 @@ def _run(
197198

198199
# Keep model on CPU after operation
199200
def on_exit():
200-
model.to("cpu")
201+
offload_torch_model_to_cpu(model)
201202

202203
with ExecutionContext(
203204
workspace=workspace,
@@ -283,7 +284,7 @@ def _run(
283284

284285
# Keep model on CPU after operation
285286
def on_exit():
286-
model.to("cpu")
287+
offload_torch_model_to_cpu(model)
287288

288289
with ExecutionContext(
289290
workspace=workspace,
@@ -375,7 +376,7 @@ def _run(
375376

376377
# Keep model on CPU after operation
377378
def on_exit():
378-
model.to("cpu")
379+
offload_torch_model_to_cpu(model)
379380

380381
if dynamo_dynamic_shapes is None:
381382
dynamic_shapes = batch_dim is not None or dynamic_axes
@@ -478,7 +479,7 @@ def _run(
478479

479480
# Keep model on CPU after operation
480481
def on_exit():
481-
model.to("cpu")
482+
offload_torch_model_to_cpu(model)
482483

483484
with ExecutionContext(
484485
workspace=workspace,

model_navigator/commands/infer_metadata.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from model_navigator.core.workspace import Workspace
3434
from model_navigator.exceptions import ModelNavigatorUserInputError
3535
from model_navigator.frameworks import Framework, is_torch_available
36+
from model_navigator.frameworks.memory import offload_model_to_cpu
3637
from model_navigator.frameworks.onnx.utils import get_onnx_io_names
3738
from model_navigator.frameworks.tensorrt.utils import get_tensorrt_io_names
3839
from model_navigator.runners.base import NavigatorRunner
@@ -290,6 +291,8 @@ def _run(
290291
else:
291292
temp_output_metadata = None
292293

294+
offload_model_to_cpu(model, framework)
295+
293296
runner_kwargs = runner_config.to_dict() if runner_config is not None else {}
294297
runner = runner_cls(
295298
model=model,
@@ -320,6 +323,8 @@ def _run(
320323

321324
output_metadata = _get_metadata_from_axes_shapes(pytree_metadata, axes_shapes, batch_dim, output_dtypes)
322325

326+
runner.deactivate()
327+
323328
return CommandOutput(
324329
status=CommandStatus.OK,
325330
output={"output_metadata": output_metadata},
model_navigator/frameworks/memory.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Memory management utilities for frameworks."""
15+
16+
from typing import Any
17+
18+
from model_navigator.frameworks import Framework, is_torch_available
19+
20+
21+
def offload_model_to_cpu(model: Any, framework: Framework):
22+
"""Offload model to CPU.
23+
24+
Args:
25+
model: Model to offload.
26+
framework: Framework of model to offload.
27+
"""
28+
if is_torch_available() and framework == Framework.TORCH:
29+
from model_navigator.frameworks.torch.utils import offload_torch_model_to_cpu
30+
31+
offload_torch_model_to_cpu(model)

0 commit comments

Comments
 (0)