Skip to content

Commit 49ba149

Browse files
[GAUDISW-244575] Reapply Adapt OnlineDefragmenter and CacheSwapUtils for t.compile
Because of the double entrypoint of CacheSwapUtils (forward and swap functions) torch.compile would process module and forward function while swap's self would refer to unwrapped module. That results in the function not being run as compiled Changes made in this patch: - Hide CacheSwapUtils entirely in OnlineDefragmenter. Let it be responsible for calling the module correctly - Moved warmup_defragmenter to defragmenter itself - Removed initialize function of OnlineDefragmenter, fully initialize object in init - Adapted unit tests for new implementations Signed-off-by: Jan Wieczorek <jwieczorek@habana.ai>
1 parent f5d8681 commit 49ba149

File tree

4 files changed

+169
-165
lines changed

4 files changed

+169
-165
lines changed

tests/unit_tests/test_defragmentation.py

Lines changed: 58 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def mock_debug_logger():
3636
@pytest.fixture
3737
def defragmenter(mock_config, mock_debug_logger):
3838
"""Create OnlineDefragmenter instance"""
39-
return OnlineDefragmenter()
39+
kv_caches = ((torch.empty(0, device='meta'), torch.empty(0, device='meta')), )
40+
return OnlineDefragmenter(kv_caches, block_size=0)
4041

4142

4243
class TestOnlineDefragmenter:
@@ -174,7 +175,9 @@ def test_free_blocks_generator(self, defragmenter):
174175
def test_defragment_disabled(self, mock_config, mock_debug_logger):
175176
"""Test defragmentation when disabled"""
176177
mock_config.defrag = False
177-
defrag = OnlineDefragmenter()
178+
179+
kv_caches = ((torch.empty(0, device='meta'), torch.empty(0, device='meta')), )
180+
defrag = OnlineDefragmenter(kv_caches, 0)
178181

179182
defrag.use_block(100)
180183
defrag.defragment()
@@ -198,11 +201,11 @@ def test_defragment_below_threshold(self, defragmenter):
198201

199202
max_before = max(defragmenter.used_blocks.keys())
200203
defragmenter._extend_mapping_table(max_before)
201-
defragmenter.cache_utils = MagicMock()
204+
defragmenter._swap = MagicMock()
202205
defragmenter.defragment()
203206

204207
# Should not trigger defragmentation
205-
defragmenter.cache_utils.swap.assert_not_called()
208+
defragmenter._swap.assert_not_called()
206209
assert max(defragmenter.used_blocks.keys()) == max_before
207210

208211
def test_defragment_triggers(self, defragmenter):
@@ -214,13 +217,13 @@ def test_defragment_triggers(self, defragmenter):
214217
defragmenter.use_block(i)
215218

216219
defragmenter._extend_mapping_table(102)
217-
defragmenter.cache_utils = MagicMock()
220+
defragmenter._swap = MagicMock()
218221

219222
defragmenter.defragment()
220223

221224
# Should call swap with high blocks moved to low positions
222-
defragmenter.cache_utils.swap.assert_called_once()
223-
args = defragmenter.cache_utils.swap.call_args[0]
225+
defragmenter._swap.assert_called_once()
226+
args = defragmenter._swap.call_args[0]
224227
to_swap = args[0]
225228
threshold = args[1]
226229

@@ -240,62 +243,20 @@ def test_defragment_early_exit(self, defragmenter):
240243
defragmenter.use_block(100)
241244

242245
defragmenter._extend_mapping_table(100)
243-
defragmenter.cache_utils = MagicMock()
246+
defragmenter._swap = MagicMock()
244247

245248
defragmenter.defragment()
246249

247250
# Free blocks: 1, 3, 4, 5...
248251
# Used blocks (descending): 100, 2
249252
# Pair (100, 1): valid swap
250253
# Pair (2, 3): 3 > 2, so break
251-
args = defragmenter.cache_utils.swap.call_args[0]
254+
args = defragmenter._swap.call_args[0]
252255
to_swap = args[0]
253256

254257
assert len(to_swap) == 1
255258
assert to_swap[0] == (100, 1)
256259

257-
258-
class TestCacheSwapUtils:
259-
"""Test suite for CacheSwapUtils"""
260-
261-
@pytest.fixture
262-
def mock_kv_caches(self):
263-
"""Create mock KV cache tensors"""
264-
num_blocks = 100
265-
block_size = 16
266-
num_heads = 8
267-
head_dim = 64
268-
num_layers = 2
269-
270-
kv_caches = []
271-
for _ in range(num_layers):
272-
k_cache = torch.randn(num_blocks * block_size, num_heads, head_dim)
273-
v_cache = torch.randn(num_blocks * block_size, num_heads, head_dim)
274-
kv_caches.append((k_cache, v_cache))
275-
return tuple(kv_caches)
276-
277-
@pytest.fixture
278-
def swap_utils(self, mock_kv_caches):
279-
"""Create CacheSwapUtils instance"""
280-
with patch('vllm_gaudi.extension.defragmentation.htorch'):
281-
return CacheSwapUtils(mock_kv_caches, block_size=16)
282-
283-
def test_cache_swap_utils_init(self, swap_utils):
284-
"""Test CacheSwapUtils initialization"""
285-
assert swap_utils.block_size == 16
286-
assert len(swap_utils.kv_caches) == 2
287-
assert swap_utils.block_slots.shape == (16, )
288-
assert swap_utils.is_mla is False
289-
290-
def test_cache_swap_utils_mla_detection(self):
291-
"""Test MLA (multi-layer attention) detection"""
292-
# Create MLA-style caches (no value cache)
293-
mla_caches = [(torch.randn(100, 8, 64), None), (torch.randn(100, 8, 64), None)]
294-
295-
with patch('vllm_gaudi.extension.defragmentation.htorch'):
296-
utils = CacheSwapUtils(tuple(mla_caches), block_size=16)
297-
assert utils.is_mla is True
298-
299260
def test_swap_execution(self):
300261
"""Test swap method execution flow on HPU"""
301262
import habana_frameworks.torch as htorch
@@ -305,14 +266,15 @@ def test_swap_execution(self):
305266
num_heads = 8
306267
head_dim = 64
307268
num_layers = 2
269+
DEVICE = 'hpu'
308270

309271
kv_caches = []
310272
for _ in range(num_layers):
311-
k_cache = torch.randn(num_blocks * block_size, num_heads, head_dim, device='hpu')
312-
v_cache = torch.randn(num_blocks * block_size, num_heads, head_dim, device='hpu')
273+
k_cache = torch.randn(num_blocks * block_size, num_heads, head_dim, device=DEVICE)
274+
v_cache = torch.randn(num_blocks * block_size, num_heads, head_dim, device=DEVICE)
313275
kv_caches.append((k_cache, v_cache))
314276

315-
swap_utils = CacheSwapUtils(tuple(kv_caches), block_size=16)
277+
defragmenter = OnlineDefragmenter(kv_caches, block_size=block_size)
316278

317279
to_swap = [(10, 5), (20, 6)]
318280
threshold = 8
@@ -321,7 +283,7 @@ def test_swap_execution(self):
321283
orig_k_10 = kv_caches[0][0][10 * block_size:(10 + 1) * block_size].clone()
322284
orig_k_5 = kv_caches[0][0][5 * block_size:(5 + 1) * block_size].clone()
323285

324-
swap_utils.swap(to_swap, threshold)
286+
defragmenter._swap(to_swap, threshold)
325287
htorch.core.mark_step()
326288

327289
# Verify blocks were swapped
@@ -335,35 +297,69 @@ def test_swap_execution(self):
335297
def test_swap_mla_single_call(self, mock_htorch):
336298
"""Test MLA swap only calls forward once (no value cache)"""
337299
mla_caches = [(torch.randn(100, 8, 64), None), (torch.randn(100, 8, 64), None)]
338-
utils = CacheSwapUtils(tuple(mla_caches), block_size=16)
300+
defragmenter = OnlineDefragmenter(mla_caches, block_size=16)
339301

340302
to_swap = [(10, 5)]
341303
threshold = 8
342304

343-
with patch.object(utils, 'forward') as mock_forward:
344-
utils.swap(to_swap, threshold)
305+
with patch.object(defragmenter.cache_utils, 'forward') as mock_forward:
306+
defragmenter._swap(to_swap, threshold)
345307

346308
# Should only be called once for keys (no values)
347309
assert mock_forward.call_count == 1
348310

349311

312+
class TestCacheSwapUtils:
313+
"""Test suite for CacheSwapUtils"""
314+
315+
@pytest.fixture
316+
def mock_kv_caches(self):
317+
"""Create mock KV cache tensors"""
318+
num_blocks = 100
319+
block_size = 16
320+
num_heads = 8
321+
head_dim = 64
322+
num_layers = 2
323+
324+
kv_caches = []
325+
for _ in range(num_layers):
326+
k_cache = torch.randn(num_blocks * block_size, num_heads, head_dim)
327+
v_cache = torch.randn(num_blocks * block_size, num_heads, head_dim)
328+
kv_caches.append((k_cache, v_cache))
329+
return tuple(kv_caches)
330+
331+
@pytest.fixture
332+
def swap_utils(self, mock_kv_caches):
333+
"""Create CacheSwapUtils instance"""
334+
with patch('vllm_gaudi.extension.defragmentation.htorch'):
335+
return CacheSwapUtils(16, 'hpu')
336+
337+
350338
class TestDefragmentationIntegration:
351339
"""Integration tests for defragmentation workflow"""
352340

353341
@pytest.fixture
354342
def setup_defragmenter(self, mock_config, mock_debug_logger):
355343
"""Setup defragmenter with mock caches"""
356-
defrag = OnlineDefragmenter()
357344

358345
# Create simple mock caches
359346
kv_caches = [(torch.zeros(1600, 8, 64), torch.zeros(1600, 8, 64)),
360347
(torch.zeros(1600, 8, 64), torch.zeros(1600, 8, 64))]
361348

362349
with patch('vllm_gaudi.extension.defragmentation.htorch'):
363-
defrag.initialize(tuple(kv_caches), block_size=16)
350+
defrag = OnlineDefragmenter(tuple(kv_caches), block_size=16)
364351

365352
return defrag
366353

354+
def test_cache_swap_utils_mla_detection(self):
355+
"""Test MLA (multi-layer attention) detection"""
356+
# Create MLA-style caches (no value cache)
357+
mla_caches = [(torch.randn(100, 8, 64), None), (torch.randn(100, 8, 64), None)]
358+
359+
with patch('vllm_gaudi.extension.defragmentation.htorch'):
360+
utils = OnlineDefragmenter(tuple(mla_caches), block_size=16)
361+
assert utils.is_mla is True
362+
367363
def test_full_lifecycle(self, setup_defragmenter):
368364
"""Test complete request lifecycle with defragmentation"""
369365
defrag = setup_defragmenter
@@ -385,7 +381,7 @@ def test_full_lifecycle(self, setup_defragmenter):
385381
defrag.update_state({'req_3': [100, 101, 102]}, [])
386382

387383
# Trigger defragmentation
388-
with patch.object(defrag.cache_utils, 'swap'):
384+
with patch.object(defrag, '_swap'):
389385
defrag.defragment()
390386

391387
def test_mapping_persistence(self, setup_defragmenter):
@@ -398,7 +394,7 @@ def test_mapping_persistence(self, setup_defragmenter):
398394

399395
defrag._extend_mapping_table(100)
400396

401-
with patch.object(defrag.cache_utils, 'swap'):
397+
with patch.object(defrag, '_swap'):
402398
defrag.defragment()
403399

404400
# Verify mappings exist

0 commit comments

Comments
 (0)