Commit 9a03e29

Moved model transfers between devices inside the export scripts. Cleaned up some operations.
1 parent c77b057 commit 9a03e29

File tree

10 files changed, +242 −132 lines changed

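The common thread across the diffs below: each export script now claims the target device itself and releases it when done, instead of leaving the model wherever the caller put it. A condensed sketch of that shape, assuming placeholder get_model and export_fn callables (not model_navigator APIs):

import gc

import torch


def run_export(get_model, export_fn, target_device="cuda"):
    """Sketch of the pattern: the script owns the device move and cleanup."""
    model = get_model()
    model.to(target_device)  # claim the device inside the script
    try:
        export_fn(model)
    finally:
        # Runs on success and failure alike: return the weights to CPU and
        # ask the caching allocator to hand unused blocks back to the driver.
        model.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()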

model_navigator/commands/data_dump/samples.py

Lines changed: 1 addition & 0 deletions

@@ -216,6 +216,7 @@ def _run(
         samples_to_npz(outputs, sample_path, batch_dim, raise_on_error=raise_on_error, num_samples=len(samples))

         runner.deactivate()
+        offload_model_to_cpu(model, framework)

         return CommandOutput(
             status=CommandStatus.OK,
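The new offload_model_to_cpu(model, framework) call frees the device as soon as sample dumping finishes. The helper's body is not part of this commit; a plausible minimal sketch of what a framework-aware offload does (the framework handling here is an assumption, not the real implementation):

import gc

import torch


def offload_model_to_cpu(model, framework):
    """Hypothetical sketch; the real model_navigator helper may differ."""
    # Only Torch modules hold CUDA allocations that can be moved this way;
    # the framework argument would let other backends opt out.
    if isinstance(model, torch.nn.Module):
        model.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()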

model_navigator/commands/export/exporters/torch2dynamo_onnx.py

Lines changed: 13 additions & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Export Torch model using dynamo."""

+import gc
 import logging
 import pathlib
 from typing import Any, Dict, List, Optional
@@ -151,6 +152,18 @@ def expand_batch_dim(tensor, batch_dim, max_batch_size):
         exported_model.save(exported_model_path.as_posix())
     finally:
         root_logger.setLevel(original_loglevel)
+        # Offload tensors to CPU
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                arg.cpu()
+        for value in kwargs.values():
+            if isinstance(value, torch.Tensor):
+                value.cpu()
+
+        del args
+        del kwargs
+        gc.collect()
+        torch.cuda.empty_cache()

     _modify_onnx_io_names(exported_model_path, input_names, output_names, exported_model_path)
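A subtlety in the cleanup above: Tensor.cpu() returns a CPU copy rather than moving the tensor in place, so the discarded return values free nothing by themselves. The GPU memory is reclaimed because del args / del kwargs drop the last references to the CUDA tensors, after which gc.collect() plus torch.cuda.empty_cache() return the cached blocks to the driver. A minimal demonstration, assuming a CUDA device is available:

import gc

import torch

args = [torch.randn(1024, 1024, device="cuda")]

for arg in args:
    if isinstance(arg, torch.Tensor):
        arg.cpu()  # returns a copy; the CUDA tensor is still alive

print(torch.cuda.memory_allocated())  # ~4 MB still held through args[0]

del args  # drop the last reference to the CUDA tensor
gc.collect()
torch.cuda.empty_cache()  # release cached blocks back to the driver

print(torch.cuda.memory_allocated())  # 0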
model_navigator/commands/export/exporters/torch2exportedprogram.py

Lines changed: 42 additions & 24 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Export Torch model using dynamo."""

+import gc
 import pathlib
 from typing import Any, Dict, Optional

@@ -59,6 +60,7 @@ def export(
         device_max_batch_size: Maximum batch size that fits on the device. Defaults to None.
     """
     model = get_model()
+    model.to(target_device)

     if not navigator_workspace:
         navigator_workspace = pathlib.Path.cwd()
@@ -112,27 +114,43 @@
         dynamic_shapes.append(dynamic_shape_map)

     try:
-        exported_model = torch.export.export(
-            model,
-            args=tuple(args),
-            kwargs=kwargs,
-            dynamic_shapes=dynamic_shapes,
-            **custom_args,
-        )
-    except Exception:
-        exported_model = torch.export._trace._export(
-            model,
-            args=tuple(args),
-            _allow_complex_guards_as_runtime_asserts=True,
-            dynamic_shapes=dynamic_shapes,
-            kwargs=kwargs,
-            **custom_args,
-        )
-
-    exported_model_path = pathlib.Path(exported_model_path)
-    if not exported_model_path.is_absolute():
-        exported_model_path = navigator_workspace / exported_model_path
-
-    torch.export.save(exported_model, exported_model_path.as_posix())
-
-    offload_torch_model_to_cpu(exported_model.module())
+        try:
+            exported_model = torch.export.export(
+                model,
+                args=tuple(args),
+                kwargs=kwargs,
+                dynamic_shapes=dynamic_shapes,
+                **custom_args,
+            )
+        except Exception:
+            exported_model = torch.export._trace._export(
+                model,
+                args=tuple(args),
+                _allow_complex_guards_as_runtime_asserts=True,
+                dynamic_shapes=dynamic_shapes,
+                kwargs=kwargs,
+                **custom_args,
+            )
+
+        exported_model_path = pathlib.Path(exported_model_path)
+        if not exported_model_path.is_absolute():
+            exported_model_path = navigator_workspace / exported_model_path
+
+        torch.export.save(exported_model, exported_model_path.as_posix())
+    finally:
+        if exported_model is not None:
+            offload_torch_model_to_cpu(exported_model.module())
+            del exported_model
+
+        # Offload tensors to CPU
+        for arg in args:
+            if isinstance(arg, torch.Tensor):
+                arg.cpu()
+        for value in kwargs.values():
+            if isinstance(value, torch.Tensor):
+                value.cpu()
+
+        del args
+        del kwargs
+        gc.collect()
+        torch.cuda.empty_cache()
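The finally block guards on exported_model is not None, which only works if the variable is bound to None before the try; that initialization has to live in unchanged context outside this hunk. A condensed sketch of the sentinel shape, with placeholder export_fn/save_fn/offload_fn callables standing in for the real calls:

def export_with_cleanup(export_fn, save_fn, offload_fn):
    exported_model = None  # sentinel, in case export_fn raises
    try:
        exported_model = export_fn()
        save_fn(exported_model)
    finally:
        # Runs whether or not the export raised; only offload a model
        # that was actually created.
        if exported_model is not None:
            offload_fn(exported_model)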

model_navigator/commands/export/exporters/torch2onnx.py

Lines changed: 32 additions & 18 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Export Torch model to ONNX model."""

+import gc
 import inspect
 import pathlib
 from typing import Any, Dict, List, Mapping, Optional
@@ -72,54 +73,67 @@ def export(
     profiling_sample = load_samples("profiling_sample", navigator_workspace, batch_dim)[0]
     input_metadata = TensorMetadata.from_json(input_metadata)

-    dummy_input = {n: torch.from_numpy(val).to(export_device) for n, val in profiling_sample.items()}
+    dummy_input_map = {n: torch.from_numpy(val).to(export_device) for n, val in profiling_sample.items()}

     # adjust input dtypes to match input_metadata
     # TODO: Remove when torch.bfloat16 will be supported
+    dummy_input = {}
     for n, spec in input_metadata.items():
         if not isinstance(spec.dtype, torch.dtype):
             torch_dtype = numpy_to_torch_dtype(spec.dtype)
         else:
             torch_dtype = spec.dtype
-        dummy_input[n] = dummy_input[n].to(torch_dtype)
+        dummy_input[n] = dummy_input_map[n].to(torch_dtype)

     dummy_input = input_metadata.unflatten_sample(dummy_input)

     # torch.onnx.export requires inputs to be a tuple or tensor
     if isinstance(dummy_input, Mapping):
         dummy_input = (dummy_input,)

-    forward_argspec = inspect.getfullargspec(model.forward)
-    forward_args = forward_argspec.args[1:]
-
     args_mapping, kwargs_mapping = input_metadata.pytree_metadata.get_names_mapping()

+    # Use inspect.signature instead of getfullargspec for more complete parameter information
+    forward_signature = inspect.signature(model.forward)
+    forward_params = list(forward_signature.parameters.keys())
+
+    args_count = len(args_mapping)
+    forward_kwargs = forward_params[args_count:]
+
     for argname in kwargs_mapping:
-        assert argname in forward_args, f"Argument {argname} is not in forward argspec."
+        assert argname in forward_kwargs, f"Argument {argname} is not in forward argspec."

     input_names = []
     for args_names in args_mapping:
         input_names.extend(args_names)

-    for argname in forward_args:
+    for argname in forward_kwargs:
         if argname in kwargs_mapping:
             input_names.extend(kwargs_mapping[argname])

     exported_model_path = pathlib.Path(exported_model_path)
     if not exported_model_path.is_absolute():
         exported_model_path = navigator_workspace / exported_model_path

-    torch.onnx.export(
-        model,
-        args=dummy_input,
-        f=exported_model_path.as_posix(),
-        verbose=False,
-        opset_version=opset,
-        input_names=input_names,
-        output_names=output_names,
-        dynamic_axes=dynamic_axes,
-        **custom_args,
-    )
+    try:
+        torch.onnx.export(
+            model,
+            args=dummy_input,
+            f=exported_model_path.as_posix(),
+            verbose=False,
+            opset_version=opset,
+            input_names=input_names,
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+            **custom_args,
+        )
+    finally:
+        for tensor in dummy_input_map.values():
+            tensor.cpu()
+
+        del dummy_input_map
+        gc.collect()
+        torch.cuda.empty_cache()


 if __name__ == "__main__":
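On the getfullargspec → inspect.signature switch: for a bound method, getfullargspec reports the underlying function's parameters with self included, which the old code had to slice off by hand; signature resolves the bound method directly. The diff then assumes the first len(args_mapping) parameters are positional and treats the remainder as keyword candidates (forward_params[args_count:]). A small self-contained demonstration (the Model class is illustrative only, not from model_navigator):

import inspect


class Model:
    def forward(self, input_ids, attention_mask=None):
        return input_ids


m = Model()

# getfullargspec exposes the raw function, 'self' included -- hence the
# old `args[1:]` slice.
print(inspect.getfullargspec(m.forward).args)
# ['self', 'input_ids', 'attention_mask']

# signature on the bound method drops 'self' automatically.
print(list(inspect.signature(m.forward).parameters))
# ['input_ids', 'attention_mask']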

model_navigator/commands/export/exporters/torch2quantized_onnx.py

Lines changed: 60 additions & 37 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Export PyTorch model to quantized ONNX using ModelOpt."""

+import gc
 import inspect
 import pathlib
 from copy import deepcopy
@@ -109,57 +110,30 @@ def export(
     offload_torch_model_to_cpu(original_model)

     try:
-        # Move model copy to target device
-        model_copy = model_copy.to(target_device)
-
         # Load calibration samples
         LOGGER.info("Loading calibration samples")
         correctness_samples = load_samples("correctness_samples", navigator_workspace, batch_dim)
         if not correctness_samples:
             LOGGER.error("No correctness samples found for calibration")
             raise RuntimeError("No calibration samples found")

-        # Convert samples to PyTorch tensors
-        torch_samples = []
-        for sample in correctness_samples:
-            sample_dict = {}
-            for name, tensor in sample.items():
-                torch_sample = torch.from_numpy(tensor)
-                torch_sample = torch_sample.to(target_device)
-                sample_dict[name] = torch_sample
-            torch_samples.append(sample_dict)
-
-        calibration_data = [list(sample.values()) for sample in torch_samples]
-
-        # Define calibration function
-        def forward_loop(model):
-            for sample in calibration_data:
-                model(*sample)
-
-        LOGGER.info("Using NVFP4_FP8_MHA_CONFIG quantization config for precision NVFP4")
-
-        # Run quantization
-        LOGGER.info("Starting model quantization (this may take several minutes)...")
-        quantized_model = mtq.quantize(
-            model_copy,
-            NVFP4_FP8_MHA_CONFIG,
-            forward_loop,
-        )
+        quantized_model = _quantize_model(model_copy, target_device, correctness_samples)

         LOGGER.info("Model quantization completed")

         # Prepare input for ONNX export
         input_metadata = TensorMetadata.from_json(input_metadata)
         correct_sample = correctness_samples[0]
-        dummy_input = {n: torch.from_numpy(val).to(target_device) for n, val in correct_sample.items()}
+        dummy_input_map = {n: torch.from_numpy(val).to(target_device) for n, val in correct_sample.items()}

         # Adjust input dtypes to match input_metadata
+        dummy_input = {}
         for n, spec in input_metadata.items():
             if not isinstance(spec.dtype, torch.dtype):
                 torch_dtype = numpy_to_torch_dtype(spec.dtype)
             else:
                 torch_dtype = spec.dtype
-            dummy_input[n] = dummy_input[n].to(torch_dtype).to(target_device)
+            dummy_input[n] = dummy_input_map[n].to(torch_dtype).to(target_device)

         dummy_input = input_metadata.unflatten_sample(dummy_input)

@@ -168,17 +142,20 @@ def forward_loop(model):
         dummy_input = (dummy_input,)

         # Get expected function signature for forward method
-        forward_argspec = inspect.getfullargspec(model_copy.forward)
-        forward_args = forward_argspec.args[1:]  # Skip 'self'
+        forward_signature = inspect.signature(model_copy.forward)
+        forward_params = list(forward_signature.parameters.keys())

         # Create input_names for ONNX model
         args_mapping, kwargs_mapping = input_metadata.pytree_metadata.get_names_mapping()

+        args_count = len(args_mapping)
+        forward_kwargs = forward_params[args_count:]
+
         input_names = []
         for args_names in args_mapping:
             input_names.extend(args_names)

-        for argname in forward_args:
+        for argname in forward_kwargs:
             if argname in kwargs_mapping:
                 input_names.extend(kwargs_mapping[argname])
         # Configure quantizers for ONNX export
@@ -226,12 +203,58 @@ def forward_loop(model):

         LOGGER.info("Quantized ONNX export completed successfully")

-        # Clean up
-        offload_torch_model_to_cpu(model_copy)
-        offload_torch_model_to_cpu(quantized_model)
     except Exception as e:
         LOGGER.error(f"Error during quantized ONNX export: {str(e)}")
         raise
+    finally:
+        # Clean up
+        if model_copy is not None:
+            offload_torch_model_to_cpu(model_copy)
+            del model_copy
+        if quantized_model is not None:
+            offload_torch_model_to_cpu(quantized_model)
+            del quantized_model
+        if dummy_input_map is not None:
+            for tensor in dummy_input_map.values():
+                tensor.cpu()
+            del dummy_input_map
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
+def _quantize_model(model, target_device, correctness_samples):
+    # Move model copy to target device
+    model = model.to(target_device)
+
+    # Convert samples to PyTorch tensors
+    torch_samples = []
+    for sample in correctness_samples:
+        sample_dict = {}
+        for name, tensor in sample.items():
+            torch_sample = torch.from_numpy(tensor)
+            torch_sample = torch_sample.to(target_device)
+            sample_dict[name] = torch_sample
+        torch_samples.append(sample_dict)
+
+    calibration_data = [list(sample.values()) for sample in torch_samples]
+
+    # Define calibration function
+    def forward_loop(model):
+        for sample in calibration_data:
+            model(*sample)
+
+    LOGGER.info("Using NVFP4_FP8_MHA_CONFIG quantization config for precision NVFP4")
+
+    # Run quantization
+    LOGGER.info("Starting model quantization (this may take several minutes)...")
+    quantized_model = mtq.quantize(
+        model,
+        NVFP4_FP8_MHA_CONFIG,
+        forward_loop,
+    )
+
+    return quantized_model


 if __name__ == "__main__":
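As in the exported-program exporter, the is not None guards in the finally block imply that model_copy, quantized_model, and dummy_input_map are bound to None before the try, outside this hunk. Beyond readability, extracting _quantize_model also has a memory effect: torch_samples, calibration_data, and the forward_loop closure become locals of the helper, so the calibration tensors lose their references as soon as it returns and the gc.collect() / torch.cuda.empty_cache() in the caller's finally block can reclaim them. A stripped-down sketch of that idea, with quantize_fn standing in for mtq.quantize bound to the fixed config:

import gc

import torch


def _quantize_model(model, target_device, samples, quantize_fn):
    model = model.to(target_device)

    # Calibration tensors live only inside this frame.
    calibration_data = [
        [torch.from_numpy(arr).to(target_device) for arr in sample.values()]
        for sample in samples
    ]

    def forward_loop(m):
        for sample in calibration_data:
            m(*sample)

    # Once this returns, calibration_data and forward_loop go out of scope,
    # so the caller's gc.collect()/torch.cuda.empty_cache() can free them.
    return quantize_fn(model, forward_loop)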
