Skip to content

Commit 26c972a

Browse files
authored
[Offload] Require offload_folder when performing disk offloading (#602)
* full inverse Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add tests to ensure that pointers are never rematerialized Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * require offload_dir, refactor tests Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * small change Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * Fix test_to_accelerate_module to provide offload_dir for disk offloading * fix quality Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * fix rebase Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 16ce668 commit 26c972a

File tree

12 files changed

+310
-238
lines changed

12 files changed

+310
-238
lines changed

src/compressed_tensors/offload/cache/disk.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import os
5-
import tempfile
65
from typing import TYPE_CHECKING, Optional
76

87
import torch
@@ -39,7 +38,12 @@ class DiskCache(OffloadCache):
3938

4039
def __init__(self, onload_device: torch.device, offload_dir: Optional[str] = None):
4140
super().__init__(onload_device)
42-
self.offload_dir = offload_dir or tempfile.mkdtemp()
41+
if offload_dir is None:
42+
raise ValueError(
43+
"Must provide an `offload_dir` to perform disk offloading "
44+
"(add `offload_folder` argument to `from_pretrained`)"
45+
)
46+
self.offload_dir = offload_dir
4347

4448
def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor | None:
4549
"""
@@ -139,7 +143,11 @@ def create_checkpoint_symlink(
139143
offload_dir: str | os.PathLike | None,
140144
) -> None:
141145
assert is_rank0(), "Must call on rank 0 to avoid id collisions between ranks"
142-
offload_dir = offload_dir or tempfile.mkdtemp()
146+
if offload_dir is None:
147+
raise ValueError(
148+
"Must provide an `offload_dir` to perform disk offloading "
149+
"(add `offload_folder` argument to `from_pretrained`)"
150+
)
143151
file_name = f"{cls._new_file_prefix}{id(offloaded)}.safetensors"
144152
file_path = os.path.join(offload_dir, file_name)
145153

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import os
5+
6+
import pytest
7+
from compressed_tensors.offload import OffloadCache
8+
9+
10+
@pytest.fixture()
11+
def offload_cache(offload_device, onload_device, tmp_path):
12+
if offload_device == "disk":
13+
offload_dir = str(tmp_path / "offload_dir")
14+
os.makedirs(offload_dir)
15+
return OffloadCache.cls_from_device(offload_device)(
16+
onload_device, offload_dir=offload_dir
17+
)
18+
else:
19+
return OffloadCache.cls_from_device(offload_device)(onload_device)

tests/test_offload/cache/helpers.py

Lines changed: 59 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,57 +5,51 @@
55
from weakref import ref
66

77
import torch
8-
from compressed_tensors.offload import OffloadCache
98
from tests.test_offload.conftest import assert_device_equal, assert_tensor_equal
109

1110

12-
def _test_onloading(offload_device, onload_device):
13-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
11+
def _test_onloading(offload_device, onload_device, offload_cache):
1412
tensor = torch.ones(10)
15-
cache["weight"] = tensor
16-
onloaded = cache["weight"]
13+
offload_cache["weight"] = tensor
14+
onloaded = offload_cache["weight"]
1715

1816
assert type(onloaded) is type(tensor)
1917
assert_tensor_equal(onloaded, tensor, onload_device)
2018

2119

22-
def _test_garbage_collect(offload_device, onload_device):
23-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
24-
cache["weight"] = torch.ones(10)
25-
onloaded = cache["weight"]
20+
def _test_garbage_collect(offload_device, onload_device, offload_cache):
21+
offload_cache["weight"] = torch.ones(10)
22+
onloaded = offload_cache["weight"]
2623

2724
onloaded_ref = ref(onloaded)
2825
del onloaded
2926
gc.collect()
3027
assert onloaded_ref() is None
3128

3229

33-
def _test_offload(offload_device, onload_device):
34-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
30+
def _test_offload(offload_device, onload_device, offload_cache):
3531
tensor = torch.ones(10, device=onload_device)
36-
offloaded = cache.offload(tensor)
32+
offloaded = offload_cache.offload(tensor)
3733
assert_device_equal(offloaded.device, offload_device)
3834
assert_tensor_equal(offloaded, tensor, offload_device)
3935

4036

41-
def _test_onload(offload_device, onload_device):
42-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
37+
def _test_onload(offload_device, onload_device, offload_cache):
4338
tensor = torch.ones(10, device=onload_device)
44-
onloaded = cache.onload(cache.offload(tensor))
39+
onloaded = offload_cache.onload(offload_cache.offload(tensor))
4540
assert_device_equal(onloaded.device, onload_device)
4641
assert_tensor_equal(onloaded, tensor, onload_device)
4742

4843

49-
def _test_disable_offloading(offload_device, onload_device):
50-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
51-
cache["weight"] = torch.ones(10)
44+
def _test_disable_offloading(offload_device, onload_device, offload_cache):
45+
offload_cache["weight"] = torch.ones(10)
5246

53-
outside_onloaded = cache["weight"]
47+
outside_onloaded = offload_cache["weight"]
5448
outside_onloaded_ref = ref(outside_onloaded)
5549
assert_device_equal(outside_onloaded.device, onload_device)
5650

57-
with cache.disable_offloading():
58-
inside_onloaded = cache["weight"]
51+
with offload_cache.disable_offloading():
52+
inside_onloaded = offload_cache["weight"]
5953
inside_onloaded_ref = ref(inside_onloaded)
6054
assert_device_equal(inside_onloaded.device, onload_device)
6155

@@ -70,26 +64,24 @@ def _test_disable_offloading(offload_device, onload_device):
7064
assert inside_onloaded_ref() is None
7165

7266

73-
def _test_disable_onloading(offload_device, onload_device):
74-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
67+
def _test_disable_onloading(offload_device, onload_device, offload_cache):
7568
tensor = torch.ones(10)
76-
cache.offloaded_values["weight"] = tensor
69+
offload_cache.offloaded_values["weight"] = tensor
7770

78-
with cache.disable_onloading():
79-
onloaded = cache["weight"]
71+
with offload_cache.disable_onloading():
72+
onloaded = offload_cache["weight"]
8073
assert onloaded is tensor
8174

8275
assert onloaded is tensor
8376

8477

85-
def _test_delete(offload_device, onload_device):
86-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
87-
cache["weight"] = torch.ones(10)
88-
onloaded = cache["weight"]
78+
def _test_delete(offload_device, onload_device, offload_cache):
79+
offload_cache["weight"] = torch.ones(10)
80+
onloaded = offload_cache["weight"]
8981
onloaded_ref = ref(onloaded)
9082

91-
with cache.disable_offloading():
92-
del cache["weight"]
83+
with offload_cache.disable_offloading():
84+
del offload_cache["weight"]
9385
del onloaded
9486
gc.collect()
9587

@@ -98,66 +90,69 @@ def _test_delete(offload_device, onload_device):
9890
assert onloaded_ref() is None
9991

10092

101-
def _test_shared_attributes(offload_device, onload_device):
102-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
103-
assert cache.offloading_disabled is cache.__class__.offloading_disabled
104-
assert cache.onloading_disabled is cache.__class__.onloading_disabled
105-
assert cache.keep_onloaded_values is cache.__class__.keep_onloaded_values
93+
def _test_shared_attributes(offload_device, onload_device, offload_cache):
94+
assert (
95+
offload_cache.offloading_disabled is offload_cache.__class__.offloading_disabled
96+
)
97+
assert (
98+
offload_cache.onloading_disabled is offload_cache.__class__.onloading_disabled
99+
)
100+
assert (
101+
offload_cache.keep_onloaded_values
102+
is offload_cache.__class__.keep_onloaded_values
103+
)
106104

107-
assert not hasattr(cache.__class__, "onload_device")
108-
assert not hasattr(cache.__class__, "offloaded_values")
105+
assert not hasattr(offload_cache.__class__, "onload_device")
106+
assert not hasattr(offload_cache.__class__, "offloaded_values")
109107

110108

111-
def _test_tensor_subclass(offload_device, onload_device):
109+
def _test_tensor_subclass(offload_device, onload_device, offload_cache):
112110
tensor = torch.ones(10)
113111
param = torch.nn.Parameter(torch.ones(10), requires_grad=False)
114112
buffer = torch.nn.Buffer(torch.ones(10))
115113

116-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
117-
cache["tensor"] = tensor
118-
cache["param"] = param
119-
cache["buffer"] = buffer
114+
offload_cache["tensor"] = tensor
115+
offload_cache["param"] = param
116+
offload_cache["buffer"] = buffer
120117

121-
assert_tensor_equal(cache["tensor"], tensor, onload_device)
122-
assert_tensor_equal(cache["param"], param, onload_device)
123-
assert_tensor_equal(cache["buffer"], buffer, onload_device)
118+
assert_tensor_equal(offload_cache["tensor"], tensor, onload_device)
119+
assert_tensor_equal(offload_cache["param"], param, onload_device)
120+
assert_tensor_equal(offload_cache["buffer"], buffer, onload_device)
124121

125-
with cache.disable_onloading():
126-
assert_tensor_equal(cache["tensor"], tensor, offload_device)
127-
assert_tensor_equal(cache["param"], param, offload_device)
128-
assert_tensor_equal(cache["buffer"], buffer, offload_device)
122+
with offload_cache.disable_onloading():
123+
assert_tensor_equal(offload_cache["tensor"], tensor, offload_device)
124+
assert_tensor_equal(offload_cache["param"], param, offload_device)
125+
assert_tensor_equal(offload_cache["buffer"], buffer, offload_device)
129126

130127

131-
def _test_update_offload(offload_device, onload_device):
132-
cache = OffloadCache.cls_from_device(offload_device)(onload_device)
133-
128+
def _test_update_offload(offload_device, onload_device, offload_cache):
134129
# Create initial tensor and offload it
135130
initial_data = torch.ones(10, device=onload_device)
136-
cache["weight"] = initial_data
131+
offload_cache["weight"] = initial_data
137132

138133
# Verify initial value
139-
onloaded = cache["weight"]
134+
onloaded = offload_cache["weight"]
140135
assert_tensor_equal(onloaded, initial_data, onload_device)
141136

142137
# Update with new data
143138
new_data = torch.ones(10, device=onload_device) * 2.0
144-
cache["weight"] = new_data
139+
offload_cache["weight"] = new_data
145140

146141
# Verify update worked
147-
updated_onloaded = cache["weight"]
142+
updated_onloaded = offload_cache["weight"]
148143
assert_tensor_equal(updated_onloaded, new_data, onload_device)
149144

150145
# Verify offloaded tensor was updated in place (not replaced)
151-
with cache.disable_onloading():
152-
offloaded = cache["weight"]
146+
with offload_cache.disable_onloading():
147+
offloaded = offload_cache["weight"]
153148
assert_tensor_equal(offloaded, new_data, offload_device)
154149

155150
# Test update with disable_offloading context
156-
with cache.disable_offloading():
157-
cache["weight"] = torch.ones(10, device=onload_device) * 3.0
158-
cached_onloaded = cache["weight"]
151+
with offload_cache.disable_offloading():
152+
offload_cache["weight"] = torch.ones(10, device=onload_device) * 3.0
153+
cached_onloaded = offload_cache["weight"]
159154
assert_tensor_equal(cached_onloaded, torch.ones(10) * 3.0, onload_device)
160155

161156
# Verify update persisted after context exit
162-
final_onloaded = cache["weight"]
157+
final_onloaded = offload_cache["weight"]
163158
assert_tensor_equal(final_onloaded, torch.ones(10) * 3.0, onload_device)

tests/test_offload/cache/test_cpu.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,71 +13,70 @@
1313
_test_onloading,
1414
_test_shared_attributes,
1515
_test_tensor_subclass,
16-
_test_update_offload,
1716
)
1817
from tests.testing_utils import requires_gpu
1918

2019

21-
ONLOAD_DEVICE = torch.device("cuda")
22-
OFFLOAD_DEVICE = torch.device("cpu")
20+
@pytest.fixture()
21+
def onload_device():
22+
return torch.device("cuda")
2323

2424

25-
@pytest.mark.unit
26-
@requires_gpu
27-
def test_delete():
28-
_test_delete(OFFLOAD_DEVICE, ONLOAD_DEVICE)
25+
@pytest.fixture()
26+
def offload_device():
27+
return torch.device("cpu")
2928

3029

3130
@pytest.mark.unit
3231
@requires_gpu
33-
def test_disable_offloading():
34-
_test_disable_offloading(OFFLOAD_DEVICE, ONLOAD_DEVICE)
32+
def test_delete(offload_device, onload_device, offload_cache):
33+
_test_delete(offload_device, onload_device, offload_cache)
3534

3635

3736
@pytest.mark.unit
3837
@requires_gpu
39-
def test_disable_onloading():
40-
_test_disable_onloading(OFFLOAD_DEVICE, ONLOAD_DEVICE)
38+
def test_disable_offloading(offload_device, onload_device, offload_cache):
39+
_test_disable_offloading(offload_device, onload_device, offload_cache)
4140

4241

4342
@pytest.mark.unit
4443
@requires_gpu
45-
def test_garbage_collect():
46-
_test_garbage_collect(OFFLOAD_DEVICE, ONLOAD_DEVICE)
44+
def test_disable_onloading(offload_device, onload_device, offload_cache):
45+
_test_disable_onloading(offload_device, onload_device, offload_cache)
4746

4847

4948
@pytest.mark.unit
5049
@requires_gpu
51-
def test_offload():
52-
_test_offload(OFFLOAD_DEVICE, ONLOAD_DEVICE)
50+
def test_garbage_collect(offload_device, onload_device, offload_cache):
51+
_test_garbage_collect(offload_device, onload_device, offload_cache)
5352

5453

5554
@pytest.mark.unit
5655
@requires_gpu
57-
@requires_gpu
58-
def test_onload():
59-
_test_onload(OFFLOAD_DEVICE, ONLOAD_DEVICE)
56+
def test_offload(offload_device, onload_device, offload_cache):
57+
_test_offload(offload_device, onload_device, offload_cache)
6058

6159

6260
@pytest.mark.unit
6361
@requires_gpu
64-
def test_onloading():
65-
_test_onloading(OFFLOAD_DEVICE, ONLOAD_DEVICE)
62+
@requires_gpu
63+
def test_onload(offload_device, onload_device, offload_cache):
64+
_test_onload(offload_device, onload_device, offload_cache)
6665

6766

6867
@pytest.mark.unit
6968
@requires_gpu
70-
def test_shared_attributes():
71-
_test_shared_attributes(OFFLOAD_DEVICE, ONLOAD_DEVICE)
69+
def test_onloading(offload_device, onload_device, offload_cache):
70+
_test_onloading(offload_device, onload_device, offload_cache)
7271

7372

7473
@pytest.mark.unit
7574
@requires_gpu
76-
def test_tensor_subclass():
77-
_test_tensor_subclass(OFFLOAD_DEVICE, ONLOAD_DEVICE)
75+
def test_shared_attributes(offload_device, onload_device, offload_cache):
76+
_test_shared_attributes(offload_device, onload_device, offload_cache)
7877

7978

8079
@pytest.mark.unit
8180
@requires_gpu
82-
def test_update_offload():
83-
_test_update_offload(OFFLOAD_DEVICE, ONLOAD_DEVICE)
81+
def test_tensor_subclass(offload_device, onload_device, offload_cache):
82+
_test_tensor_subclass(offload_device, onload_device, offload_cache)

0 commit comments

Comments
 (0)