Skip to content

Commit 68f1a7d

Browse files
[Refactor] Rename offload_model to set_onload_device (#643)
* [Refactor] Rename offload_model to set_onload_device - Add set_onload_device as the canonical function (replaces offload_model) - Deprecate offload_model using @deprecated decorator pointing to set_onload_device - Remove offload_device param from set_onload_device (was already ignored with warning) - Update all internal usages and tests Part of vllm-project/llm-compressor#2483 * [Docs] Update README references from offload_model to set_onload_device * Run make commands to apply formatting --------- Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent 516354e commit 68f1a7d

File tree

12 files changed

+43
-34
lines changed

12 files changed

+43
-34
lines changed

src/compressed_tensors/offload/README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ Offloads tensors to CPU RAM. Onloading is a standard `.to(device)` call from CPU
128128

129129
#### `DeviceCache``cache/device.py`
130130

131-
Offloads tensors to a CUDA device. Onloading is typically a no-op (the tensor is already on device), but handles the case where `onload_device` is changed after initialization (e.g., during `offload_model` reconfiguration).
131+
Offloads tensors to a CUDA device. Onloading is typically a no-op (the tensor is already on device), but handles the case where `onload_device` is changed after initialization (e.g., during `set_onload_device` reconfiguration).
132132

133133
- **offload**: moves tensor to the device (`self.offload_device = self.onload_device` at init).
134134
- **onload**: `send_tensors(offloaded, device=self.onload_device)`.
@@ -214,7 +214,7 @@ The primary function for attaching offloading to a single `torch.nn.Module`. It:
214214
offload_module(layer, onload_device="cuda:0", offload_device="cpu")
215215
```
216216

217-
**When to use:** when you want fine-grained control over which specific modules are offloaded. For model-wide dispatch, prefer `dispatch_model` or `offload_model`.
217+
**When to use:** when you want fine-grained control over which specific modules are offloaded. For model-wide dispatch, prefer `dispatch_model` or `set_onload_device`.
218218

219219
> **Note:** Raises `ValueError` if the module is already offloaded. Call `remove_module_offload` first.
220220
@@ -273,17 +273,19 @@ model = dispatch_model(model, device_memory={torch.device("cuda:0"): 16e9})
273273

274274
---
275275

276-
#### `offload_model(model, onload_device, offload_device=None)`
276+
#### `set_onload_device(model, onload_device)`
277277

278278
A lighter-weight dispatch that moves all modules in a model to the same `onload_device`, without changing where weights are stored. For modules not yet offloaded, it offloads them to their current device.
279279

280280
```python
281281
# Move all execution to cuda:0, keeping offloads unchanged
282-
model = offload_model(model, onload_device="cuda:0")
282+
model = set_onload_device(model, onload_device="cuda:0")
283283
```
284284

285285
**When to use:** when you have already loaded a model with weights in the right place (e.g., via `load_offloaded_model`) and just need to set the execution device. Less powerful than `dispatch_model` but simpler.
286286

287+
> **Note:** `offload_model` is a deprecated alias for this function.
288+
287289
---
288290

289291
#### `dispatch_with_map(model, device_map, offload_dir=None)`
@@ -684,7 +686,7 @@ compressed_tensors.offload
684686
├── load.py load_offloaded_model()
685687
│ └── calls from_accelerate after loading
686688
687-
├── dispatch.py dispatch_model(), offload_model(), dispatch_with_map()
689+
├── dispatch.py dispatch_model(), set_onload_device(), dispatch_with_map()
688690
│ └── calls offload_module() for each module
689691
690692
├── module.py offload_module(), remove_module_offload()

src/compressed_tensors/offload/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_device_map,
1515
offload_model,
1616
remove_dispatch,
17+
set_onload_device,
1718
)
1819
from compressed_tensors.offload.dist_utils import (
1920
as_broadcastable,
@@ -29,7 +30,8 @@
2930

3031
__all__ = [
3132
# dispatch models
32-
"offload_model",
33+
"set_onload_device",
34+
"offload_model", # deprecated, use set_onload_device
3335
"dispatch_model",
3436
"remove_dispatch",
3537
"dispatch_with_map",

src/compressed_tensors/offload/cache/dist_cpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def offload(self, tensor: torch.Tensor | None) -> torch.Tensor | None:
2828
if dist.get_rank() == 0:
2929
# create shared memory cpu tensor
3030
tensor = super().offload(tensor).share_memory_()
31-
(handle, filename, nbytes) = tensor.untyped_storage()._share_filename_cpu_()
31+
handle, filename, nbytes = tensor.untyped_storage()._share_filename_cpu_()
3232
broadcast_obj = [handle, filename, nbytes]
3333
else:
3434
broadcast_obj = [None, None, None]

src/compressed_tensors/offload/dispatch.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
)
1818
from compressed_tensors.utils import getattr_chain
1919
from compressed_tensors.utils.binary_search import SearchFailureError, max_binary_search
20+
from compressed_tensors.utils.helpers import deprecated
2021
from loguru import logger
2122
from transformers import PreTrainedModel
2223

2324

2425
__all__ = [
26+
"set_onload_device",
2527
"offload_model",
2628
"dispatch_with_map",
2729
"get_device_map",
@@ -35,28 +37,19 @@
3537
DeviceMap = dict[str, tuple[torch.device | None, torch.device | str | None]]
3638

3739

38-
def offload_model(
40+
def set_onload_device(
3941
model: ModelType,
4042
onload_device: torch.device | str,
41-
offload_device: Any = None,
4243
) -> ModelType:
4344
"""
4445
Modify the dispatch of a model to onload to the provided `onload_device`. Existing
45-
offloaded tensors will not be modified. If a module is not offloaded, it will be
46-
offloaded to the provided `offload_device`.
46+
offloaded tensors will not be modified. If a module is not already offloaded, it
47+
will be offloaded to its current device.
4748
4849
:param model: model to dispatch
4950
:param onload_device: device to move weights to during forward pass
50-
:param offload_device: device to offload weights to, if not already offloaded
5151
:return: dispatched model
5252
"""
53-
if offload_device is not None:
54-
logger.warning(
55-
"`offload_model` now keeps the same offload device that model was loaded "
56-
"on. Please specify offload by loading the model on its offload device(s)"
57-
)
58-
59-
# offload modules in place
6053
for module in model.modules():
6154
if isinstance(module._parameters, OffloadCache):
6255
module._parameters.onload_device = onload_device
@@ -68,6 +61,19 @@ def offload_model(
6861
return model
6962

7063

64+
@deprecated("set_onload_device")
65+
def offload_model(
66+
model: ModelType,
67+
onload_device: torch.device | str,
68+
offload_device: Any = None,
69+
) -> ModelType:
70+
"""
71+
.. deprecated::
72+
Use :func:`set_onload_device` instead.
73+
"""
74+
return set_onload_device(model, onload_device)
75+
76+
7177
def dispatch_with_map(
7278
model: torch.nn.Module,
7379
device_map: DeviceMap,

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
137137
output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
138138
output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
139139
output = output.view(input.shape[0], input.shape[1], -1)
140-
(input, output) = (output, input)
140+
input, output = (output, input)
141141
assert input.shape[1] == K
142142
del output
143143

src/compressed_tensors/utils/offload.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
disable_offloading,
2222
get_execution_device,
2323
get_offloaded_device,
24-
offload_model,
2524
register_offload_module,
2625
remove_dispatch,
26+
set_onload_device,
2727
update_offload_parameter,
2828
)
2929
from compressed_tensors.utils.helpers import deprecated
@@ -134,7 +134,7 @@ def delete_offload_module(base: torch.nn.Module, name: str):
134134
delattr(base, name)
135135

136136

137-
@deprecated("compressed_tensors.offload::offload_model")
137+
@deprecated("compressed_tensors.offload::set_onload_device")
138138
def offloaded_dispatch(
139139
module: torch.nn.Module,
140140
execution_device: torch.device,
@@ -152,7 +152,7 @@ def offloaded_dispatch(
152152
raise ValueError(
153153
"Passing offload_device to offloaded_dispatch is no longer supported"
154154
)
155-
offload_model(module, execution_device)
155+
set_onload_device(module, execution_device)
156156

157157

158158
@deprecated("compressed_tensors.offload::align_module_device")

src/compressed_tensors/utils/semi_structured_conversions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import torch
1212

13-
1413
__all__ = [
1514
"sparse_semi_structured_from_dense_cutlass",
1615
"sparse_semi_structured_to_dense_cutlass",

tests/test_offload/test_dispatch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from compressed_tensors.offload.dispatch import (
1010
dispatch_model,
1111
get_device_memory,
12-
offload_model,
12+
set_onload_device,
1313
)
1414
from compressed_tensors.offload.utils import module_size
1515
from tests.testing_utils import requires_gpu
@@ -190,7 +190,7 @@ def test_offload_and_dispatch_model(model_id):
190190

191191
# offload entire model
192192
model.to("cpu")
193-
model = offload_model(model, "cuda:0")
193+
model = set_onload_device(model, "cuda:0")
194194
offloaded_logits = model(**sample).logits
195195
for module in model.modules():
196196
assert_module_offloaded(module, "cuda:0", torch.device("cpu"))

tests/test_quantization/lifecycle/test_initialize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77
import torch
8-
from compressed_tensors.offload import offload_model
8+
from compressed_tensors.offload import set_onload_device
99
from compressed_tensors.quantization import (
1010
FP8_E4M3_DATA,
1111
ActivationOrdering,
@@ -108,7 +108,7 @@ def test_initialize_module_for_quantization(
108108
def test_initialize_module_for_quantization_offloaded(
109109
create_quantization_scheme, weights, input_activations, layer
110110
):
111-
offload_model(layer, "cuda:0")
111+
set_onload_device(layer, "cuda:0")
112112

113113
test_initialize_module_for_quantization(
114114
create_quantization_scheme,

tests/test_transform/factory/test_correctness.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55
import torch
6-
from compressed_tensors.offload import offload_model
6+
from compressed_tensors.offload import set_onload_device
77
from compressed_tensors.transform import (
88
TransformArgs,
99
TransformConfig,
@@ -87,7 +87,7 @@ def test_correctness_model(type, randomize, input_batch_size, model_apply, offlo
8787
# load model
8888
model = model_apply[0]
8989
if offload:
90-
offload_model(model, torch.device("cuda"))
90+
set_onload_device(model, torch.device("cuda"))
9191

9292
# get output
9393
input = torch.rand((input_batch_size, 5, model.fcs[0].in_features))

0 commit comments

Comments (0)