@@ -36,7 +36,8 @@ def mock_debug_logger():
3636@pytest .fixture
3737def defragmenter (mock_config , mock_debug_logger ):
3838 """Create OnlineDefragmenter instance"""
39- return OnlineDefragmenter ()
39+ kv_caches = ((torch .empty (0 , device = 'meta' ), torch .empty (0 , device = 'meta' )), )
40+ return OnlineDefragmenter (kv_caches , block_size = 0 )
4041
4142
4243class TestOnlineDefragmenter :
@@ -174,7 +175,9 @@ def test_free_blocks_generator(self, defragmenter):
174175 def test_defragment_disabled (self , mock_config , mock_debug_logger ):
175176 """Test defragmentation when disabled"""
176177 mock_config .defrag = False
177- defrag = OnlineDefragmenter ()
178+
179+ kv_caches = ((torch .empty (0 , device = 'meta' ), torch .empty (0 , device = 'meta' )), )
180+ defrag = OnlineDefragmenter (kv_caches , 0 )
178181
179182 defrag .use_block (100 )
180183 defrag .defragment ()
@@ -198,11 +201,11 @@ def test_defragment_below_threshold(self, defragmenter):
198201
199202 max_before = max (defragmenter .used_blocks .keys ())
200203 defragmenter ._extend_mapping_table (max_before )
201- defragmenter .cache_utils = MagicMock ()
204+ defragmenter ._swap = MagicMock ()
202205 defragmenter .defragment ()
203206
204207 # Should not trigger defragmentation
205- defragmenter .cache_utils . swap .assert_not_called ()
208+ defragmenter ._swap .assert_not_called ()
206209 assert max (defragmenter .used_blocks .keys ()) == max_before
207210
208211 def test_defragment_triggers (self , defragmenter ):
@@ -214,13 +217,13 @@ def test_defragment_triggers(self, defragmenter):
214217 defragmenter .use_block (i )
215218
216219 defragmenter ._extend_mapping_table (102 )
217- defragmenter .cache_utils = MagicMock ()
220+ defragmenter ._swap = MagicMock ()
218221
219222 defragmenter .defragment ()
220223
221224 # Should call swap with high blocks moved to low positions
222- defragmenter .cache_utils . swap .assert_called_once ()
223- args = defragmenter .cache_utils . swap .call_args [0 ]
225+ defragmenter ._swap .assert_called_once ()
226+ args = defragmenter ._swap .call_args [0 ]
224227 to_swap = args [0 ]
225228 threshold = args [1 ]
226229
@@ -240,62 +243,20 @@ def test_defragment_early_exit(self, defragmenter):
240243 defragmenter .use_block (100 )
241244
242245 defragmenter ._extend_mapping_table (100 )
243- defragmenter .cache_utils = MagicMock ()
246+ defragmenter ._swap = MagicMock ()
244247
245248 defragmenter .defragment ()
246249
247250 # Free blocks: 1, 3, 4, 5...
248251 # Used blocks (descending): 100, 2
249252 # Pair (100, 1): valid swap
250253 # Pair (2, 3): 3 > 2, so break
251- args = defragmenter .cache_utils . swap .call_args [0 ]
254+ args = defragmenter ._swap .call_args [0 ]
252255 to_swap = args [0 ]
253256
254257 assert len (to_swap ) == 1
255258 assert to_swap [0 ] == (100 , 1 )
256259
257-
258- class TestCacheSwapUtils :
259- """Test suite for CacheSwapUtils"""
260-
261- @pytest .fixture
262- def mock_kv_caches (self ):
263- """Create mock KV cache tensors"""
264- num_blocks = 100
265- block_size = 16
266- num_heads = 8
267- head_dim = 64
268- num_layers = 2
269-
270- kv_caches = []
271- for _ in range (num_layers ):
272- k_cache = torch .randn (num_blocks * block_size , num_heads , head_dim )
273- v_cache = torch .randn (num_blocks * block_size , num_heads , head_dim )
274- kv_caches .append ((k_cache , v_cache ))
275- return tuple (kv_caches )
276-
277- @pytest .fixture
278- def swap_utils (self , mock_kv_caches ):
279- """Create CacheSwapUtils instance"""
280- with patch ('vllm_gaudi.extension.defragmentation.htorch' ):
281- return CacheSwapUtils (mock_kv_caches , block_size = 16 )
282-
283- def test_cache_swap_utils_init (self , swap_utils ):
284- """Test CacheSwapUtils initialization"""
285- assert swap_utils .block_size == 16
286- assert len (swap_utils .kv_caches ) == 2
287- assert swap_utils .block_slots .shape == (16 , )
288- assert swap_utils .is_mla is False
289-
290- def test_cache_swap_utils_mla_detection (self ):
291- """Test MLA (multi-layer attention) detection"""
292- # Create MLA-style caches (no value cache)
293- mla_caches = [(torch .randn (100 , 8 , 64 ), None ), (torch .randn (100 , 8 , 64 ), None )]
294-
295- with patch ('vllm_gaudi.extension.defragmentation.htorch' ):
296- utils = CacheSwapUtils (tuple (mla_caches ), block_size = 16 )
297- assert utils .is_mla is True
298-
299260 def test_swap_execution (self ):
300261 """Test swap method execution flow on HPU"""
301262 import habana_frameworks .torch as htorch
@@ -305,14 +266,15 @@ def test_swap_execution(self):
305266 num_heads = 8
306267 head_dim = 64
307268 num_layers = 2
269+ DEVICE = 'hpu'
308270
309271 kv_caches = []
310272 for _ in range (num_layers ):
311- k_cache = torch .randn (num_blocks * block_size , num_heads , head_dim , device = 'hpu' )
312- v_cache = torch .randn (num_blocks * block_size , num_heads , head_dim , device = 'hpu' )
273+ k_cache = torch .randn (num_blocks * block_size , num_heads , head_dim , device = DEVICE )
274+ v_cache = torch .randn (num_blocks * block_size , num_heads , head_dim , device = DEVICE )
313275 kv_caches .append ((k_cache , v_cache ))
314276
315- swap_utils = CacheSwapUtils ( tuple ( kv_caches ) , block_size = 16 )
277+ defragmenter = OnlineDefragmenter ( kv_caches , block_size = block_size )
316278
317279 to_swap = [(10 , 5 ), (20 , 6 )]
318280 threshold = 8
@@ -321,7 +283,7 @@ def test_swap_execution(self):
321283 orig_k_10 = kv_caches [0 ][0 ][10 * block_size :(10 + 1 ) * block_size ].clone ()
322284 orig_k_5 = kv_caches [0 ][0 ][5 * block_size :(5 + 1 ) * block_size ].clone ()
323285
324- swap_utils . swap (to_swap , threshold )
286+ defragmenter . _swap (to_swap , threshold )
325287 htorch .core .mark_step ()
326288
327289 # Verify blocks were swapped
@@ -335,35 +297,69 @@ def test_swap_execution(self):
335297 def test_swap_mla_single_call (self , mock_htorch ):
336298 """Test MLA swap only calls forward once (no value cache)"""
337299 mla_caches = [(torch .randn (100 , 8 , 64 ), None ), (torch .randn (100 , 8 , 64 ), None )]
338- utils = CacheSwapUtils ( tuple ( mla_caches ) , block_size = 16 )
300+ defragmenter = OnlineDefragmenter ( mla_caches , block_size = 16 )
339301
340302 to_swap = [(10 , 5 )]
341303 threshold = 8
342304
343- with patch .object (utils , 'forward' ) as mock_forward :
344- utils . swap (to_swap , threshold )
305+ with patch .object (defragmenter . cache_utils , 'forward' ) as mock_forward :
306+ defragmenter . _swap (to_swap , threshold )
345307
346308 # Should only be called once for keys (no values)
347309 assert mock_forward .call_count == 1
348310
349311
312+ class TestCacheSwapUtils :
313+ """Test suite for CacheSwapUtils"""
314+
315+ @pytest .fixture
316+ def mock_kv_caches (self ):
317+ """Create mock KV cache tensors"""
318+ num_blocks = 100
319+ block_size = 16
320+ num_heads = 8
321+ head_dim = 64
322+ num_layers = 2
323+
324+ kv_caches = []
325+ for _ in range (num_layers ):
326+ k_cache = torch .randn (num_blocks * block_size , num_heads , head_dim )
327+ v_cache = torch .randn (num_blocks * block_size , num_heads , head_dim )
328+ kv_caches .append ((k_cache , v_cache ))
329+ return tuple (kv_caches )
330+
331+ @pytest .fixture
332+ def swap_utils (self , mock_kv_caches ):
333+ """Create CacheSwapUtils instance"""
334+ with patch ('vllm_gaudi.extension.defragmentation.htorch' ):
335+ return CacheSwapUtils (16 , 'hpu' )
336+
337+
350338class TestDefragmentationIntegration :
351339 """Integration tests for defragmentation workflow"""
352340
353341 @pytest .fixture
354342 def setup_defragmenter (self , mock_config , mock_debug_logger ):
355343 """Setup defragmenter with mock caches"""
356- defrag = OnlineDefragmenter ()
357344
358345 # Create simple mock caches
359346 kv_caches = [(torch .zeros (1600 , 8 , 64 ), torch .zeros (1600 , 8 , 64 )),
360347 (torch .zeros (1600 , 8 , 64 ), torch .zeros (1600 , 8 , 64 ))]
361348
362349 with patch ('vllm_gaudi.extension.defragmentation.htorch' ):
363- defrag . initialize (tuple (kv_caches ), block_size = 16 )
350+ defrag = OnlineDefragmenter (tuple (kv_caches ), block_size = 16 )
364351
365352 return defrag
366353
354+ def test_cache_swap_utils_mla_detection (self ):
355+ """Test MLA (Multi-head Latent Attention) detection"""
356+ # Create MLA-style caches (no value cache)
357+ mla_caches = [(torch .randn (100 , 8 , 64 ), None ), (torch .randn (100 , 8 , 64 ), None )]
358+
359+ with patch ('vllm_gaudi.extension.defragmentation.htorch' ):
360+ utils = OnlineDefragmenter (tuple (mla_caches ), block_size = 16 )
361+ assert utils .is_mla is True
362+
367363 def test_full_lifecycle (self , setup_defragmenter ):
368364 """Test complete request lifecycle with defragmentation"""
369365 defrag = setup_defragmenter
@@ -385,7 +381,7 @@ def test_full_lifecycle(self, setup_defragmenter):
385381 defrag .update_state ({'req_3' : [100 , 101 , 102 ]}, [])
386382
387383 # Trigger defragmentation
388- with patch .object (defrag . cache_utils , 'swap ' ):
384+ with patch .object (defrag , '_swap ' ):
389385 defrag .defragment ()
390386
391387 def test_mapping_persistence (self , setup_defragmenter ):
@@ -398,7 +394,7 @@ def test_mapping_persistence(self, setup_defragmenter):
398394
399395 defrag ._extend_mapping_table (100 )
400396
401- with patch .object (defrag . cache_utils , 'swap ' ):
397+ with patch .object (defrag , '_swap ' ):
402398 defrag .defragment ()
403399
404400 # Verify mappings exist
0 commit comments