55from weakref import ref
66
77import torch
8- from compressed_tensors .offload import OffloadCache
98from tests .test_offload .conftest import assert_device_equal , assert_tensor_equal
109
1110
def _test_onloading(offload_device, onload_device, offload_cache):
    """Round-trip a tensor through ``offload_cache`` by key.

    Stores ``torch.ones(10)`` under ``"weight"``, reads it back, and asserts
    that the returned object has exactly the same type as the stored tensor
    and that it passes ``assert_tensor_equal`` for ``onload_device``.

    :param offload_device: not used by this helper; present so that it takes
        the same three arguments as the other helpers in this module
    :param onload_device: device handed to ``assert_tensor_equal``
    :param offload_cache: cache under test; must support item get/set
    """
    tensor = torch.ones(10)
    offload_cache["weight"] = tensor
    onloaded = offload_cache["weight"]

    # exact type identity (not isinstance): the read must not change the class
    assert type(onloaded) is type(tensor)
    assert_tensor_equal(onloaded, tensor, onload_device)
2018
2119
22- def _test_garbage_collect (offload_device , onload_device ):
23- cache = OffloadCache .cls_from_device (offload_device )(onload_device )
24- cache ["weight" ] = torch .ones (10 )
25- onloaded = cache ["weight" ]
20+ def _test_garbage_collect (offload_device , onload_device , offload_cache ):
21+ offload_cache ["weight" ] = torch .ones (10 )
22+ onloaded = offload_cache ["weight" ]
2623
2724 onloaded_ref = ref (onloaded )
2825 del onloaded
2926 gc .collect ()
3027 assert onloaded_ref () is None
3128
3229
def _test_offload(offload_device, onload_device, offload_cache):
    """Check ``offload_cache.offload`` on a tensor created on ``onload_device``.

    Asserts via ``assert_device_equal`` that the returned tensor's device
    matches ``offload_device``, then passes the result, the input tensor and
    ``offload_device`` to ``assert_tensor_equal``.

    :param offload_device: device the offloaded tensor is expected to be on
    :param onload_device: device the input tensor is created on
    :param offload_cache: cache under test; must provide ``offload``
    """
    tensor = torch.ones(10, device=onload_device)
    offloaded = offload_cache.offload(tensor)
    assert_device_equal(offloaded.device, offload_device)
    assert_tensor_equal(offloaded, tensor, offload_device)
3935
4036
def _test_onload(offload_device, onload_device, offload_cache):
    """Check that ``onload`` applied to the result of ``offload`` round-trips.

    Creates a tensor on ``onload_device``, offloads it, onloads the result,
    and asserts via ``assert_device_equal`` / ``assert_tensor_equal`` that the
    final tensor is on ``onload_device`` and matches the input.

    :param offload_device: not used by this helper
    :param onload_device: device of the input and of the expected result
    :param offload_cache: cache under test; must provide ``offload``/``onload``
    """
    tensor = torch.ones(10, device=onload_device)
    onloaded = offload_cache.onload(offload_cache.offload(tensor))
    assert_device_equal(onloaded.device, onload_device)
    assert_tensor_equal(onloaded, tensor, onload_device)
4742
4843
49- def _test_disable_offloading (offload_device , onload_device ):
50- cache = OffloadCache .cls_from_device (offload_device )(onload_device )
51- cache ["weight" ] = torch .ones (10 )
44+ def _test_disable_offloading (offload_device , onload_device , offload_cache ):
45+ offload_cache ["weight" ] = torch .ones (10 )
5246
53- outside_onloaded = cache ["weight" ]
47+ outside_onloaded = offload_cache ["weight" ]
5448 outside_onloaded_ref = ref (outside_onloaded )
5549 assert_device_equal (outside_onloaded .device , onload_device )
5650
57- with cache .disable_offloading ():
58- inside_onloaded = cache ["weight" ]
51+ with offload_cache .disable_offloading ():
52+ inside_onloaded = offload_cache ["weight" ]
5953 inside_onloaded_ref = ref (inside_onloaded )
6054 assert_device_equal (inside_onloaded .device , onload_device )
6155
@@ -70,26 +64,24 @@ def _test_disable_offloading(offload_device, onload_device):
7064 assert inside_onloaded_ref () is None
7165
7266
73- def _test_disable_onloading (offload_device , onload_device ):
74- cache = OffloadCache .cls_from_device (offload_device )(onload_device )
67+ def _test_disable_onloading (offload_device , onload_device , offload_cache ):
7568 tensor = torch .ones (10 )
76- cache .offloaded_values ["weight" ] = tensor
69+ offload_cache .offloaded_values ["weight" ] = tensor
7770
78- with cache .disable_onloading ():
79- onloaded = cache ["weight" ]
71+ with offload_cache .disable_onloading ():
72+ onloaded = offload_cache ["weight" ]
8073 assert onloaded is tensor
8174
8275 assert onloaded is tensor
8376
8477
85- def _test_delete (offload_device , onload_device ):
86- cache = OffloadCache .cls_from_device (offload_device )(onload_device )
87- cache ["weight" ] = torch .ones (10 )
88- onloaded = cache ["weight" ]
78+ def _test_delete (offload_device , onload_device , offload_cache ):
79+ offload_cache ["weight" ] = torch .ones (10 )
80+ onloaded = offload_cache ["weight" ]
8981 onloaded_ref = ref (onloaded )
9082
91- with cache .disable_offloading ():
92- del cache ["weight" ]
83+ with offload_cache .disable_offloading ():
84+ del offload_cache ["weight" ]
9385 del onloaded
9486 gc .collect ()
9587
@@ -98,66 +90,69 @@ def _test_delete(offload_device, onload_device):
9890 assert onloaded_ref () is None
9991
10092
101- def _test_shared_attributes (offload_device , onload_device ):
102- cache = OffloadCache .cls_from_device (offload_device )(onload_device )
103- assert cache .offloading_disabled is cache .__class__ .offloading_disabled
104- assert cache .onloading_disabled is cache .__class__ .onloading_disabled
105- assert cache .keep_onloaded_values is cache .__class__ .keep_onloaded_values
93+ def _test_shared_attributes (offload_device , onload_device , offload_cache ):
94+ assert (
95+ offload_cache .offloading_disabled is offload_cache .__class__ .offloading_disabled
96+ )
97+ assert (
98+ offload_cache .onloading_disabled is offload_cache .__class__ .onloading_disabled
99+ )
100+ assert (
101+ offload_cache .keep_onloaded_values
102+ is offload_cache .__class__ .keep_onloaded_values
103+ )
106104
107- assert not hasattr (cache .__class__ , "onload_device" )
108- assert not hasattr (cache .__class__ , "offloaded_values" )
105+ assert not hasattr (offload_cache .__class__ , "onload_device" )
106+ assert not hasattr (offload_cache .__class__ , "offloaded_values" )
109107
110108
def _test_tensor_subclass(offload_device, onload_device, offload_cache):
    """Store a plain tensor, a ``Parameter`` and a ``Buffer`` and read them back.

    Each of the three values is stored under its own key. They are first read
    normally and passed to ``assert_tensor_equal`` with ``onload_device``,
    then read again inside ``disable_onloading`` and passed to
    ``assert_tensor_equal`` with ``offload_device``.

    :param offload_device: device passed to the checks made while onloading
        is disabled
    :param onload_device: device passed to the checks made on normal reads
    :param offload_cache: cache under test; must support item get/set and
        ``disable_onloading()``
    """
    tensor = torch.ones(10)
    param = torch.nn.Parameter(torch.ones(10), requires_grad=False)
    buffer = torch.nn.Buffer(torch.ones(10))

    offload_cache["tensor"] = tensor
    offload_cache["param"] = param
    offload_cache["buffer"] = buffer

    # normal reads
    assert_tensor_equal(offload_cache["tensor"], tensor, onload_device)
    assert_tensor_equal(offload_cache["param"], param, onload_device)
    assert_tensor_equal(offload_cache["buffer"], buffer, onload_device)

    # reads with onloading disabled
    with offload_cache.disable_onloading():
        assert_tensor_equal(offload_cache["tensor"], tensor, offload_device)
        assert_tensor_equal(offload_cache["param"], param, offload_device)
        assert_tensor_equal(offload_cache["buffer"], buffer, offload_device)
129126
130127
def _test_update_offload(offload_device, onload_device, offload_cache):
    """Overwrite an existing key and check the new value is what gets read.

    Sequence: store ones, read back; store twos under the same key, read
    back normally and with onloading disabled; store threes inside
    ``disable_offloading`` and read back both inside and after the context.
    Every read is checked with ``assert_tensor_equal``.

    :param offload_device: device passed to the check made while onloading
        is disabled
    :param onload_device: device the input tensors are created on and the
        device passed to the checks on normal reads
    :param offload_cache: cache under test; must support item get/set,
        ``disable_onloading()`` and ``disable_offloading()``
    """
    # Create initial tensor and offload it
    initial_data = torch.ones(10, device=onload_device)
    offload_cache["weight"] = initial_data

    # Verify initial value
    onloaded = offload_cache["weight"]
    assert_tensor_equal(onloaded, initial_data, onload_device)

    # Update with new data
    new_data = torch.ones(10, device=onload_device) * 2.0
    offload_cache["weight"] = new_data

    # Verify update worked
    updated_onloaded = offload_cache["weight"]
    assert_tensor_equal(updated_onloaded, new_data, onload_device)

    # Verify the offloaded tensor reflects the update.
    # NOTE(review): the intent appears to be "updated in place (not
    # replaced)", but no reference to the previously offloaded tensor is
    # kept, so only the value is compared here, not storage identity.
    with offload_cache.disable_onloading():
        offloaded = offload_cache["weight"]
        assert_tensor_equal(offloaded, new_data, offload_device)

    # Test update with disable_offloading context
    with offload_cache.disable_offloading():
        offload_cache["weight"] = torch.ones(10, device=onload_device) * 3.0
        cached_onloaded = offload_cache["weight"]
        assert_tensor_equal(cached_onloaded, torch.ones(10) * 3.0, onload_device)

    # Verify update persisted after context exit
    final_onloaded = offload_cache["weight"]
    assert_tensor_equal(final_onloaded, torch.ones(10) * 3.0, onload_device)
0 commit comments