@@ -1352,27 +1352,34 @@ def defrag_cache(self):
13521352 if not self .paged :
13531353 return
13541354
1355+ # Defragment once job queue is empty after touching all the cache pages
13551356 if self .access_serial < self .last_defrag_serial + self .max_pages :
13561357 return
13571358 self .last_defrag_serial = self .access_serial
13581359
13591360 assert not self .referenced_pages
13601361
1361- @dataclass
class CacheNode:
    """Node of the temporary page tree used while defragmenting the cache.

    Each node wraps one cache page (or ``None`` for a page-less node such as
    the tree root) and tracks its child nodes. Nodes hash and compare by
    identity, so they can be stored in sets even when pages look alike.
    """

    # Wrapped cache page; None for a node that carries no page.
    page: "CachePage | None"
    # Parent node; starts as None and is left for the tree builder to assign.
    parent: "CacheNode | None"
    # Unordered set of child nodes.
    children: "set[CacheNode] | None"
    # Children ordered by ascending left_page; None until presort() has run.
    children_sorted: "deque[CacheNode] | None"
    # Sort key for ordering siblings; initialised from the page's access_serial.
    left_page: int = 0

    def __init__(self, page_):
        """Create a childless, parentless node wrapping ``page_``.

        ``left_page`` is taken from ``page_.access_serial`` when a page is
        supplied and is 0 otherwise.
        """
        self.page = page_
        self.parent = None
        self.children = set()
        self.children_sorted = None
        self.left_page = page_.access_serial if page_ else 0

    def __hash__(self):
        # Identity hash, paired with the identity-based __eq__ below.
        return id(self)

    def __eq__(self, other):
        return self is other

    def presort(self, recursive=True):
        """Populate ``children_sorted`` with the children ordered by ``left_page``.

        With ``recursive=True`` (the default) every descendant is sorted as
        well; with ``recursive=False`` only this node is re-sorted.

        The traversal uses an explicit stack instead of recursion so that a
        very deep tree (e.g. one long chain of pages) cannot exceed Python's
        recursion limit.
        """
        pending = [self]
        while pending:
            node = pending.pop()
            node.children_sorted = deque(sorted(node.children, key=lambda x: x.left_page))
            if recursive:
                pending.extend(node.children)
13761383
13771384 # Build a tree of the current cache
13781385
@@ -1393,28 +1400,50 @@ def __eq__(self, other):
13931400
13941401 # Remove oldest branch until tree is empty
13951402
1403+ root_node .presort ()
1404+ shift_counts = {}
1405+
13961406 new_page_index = 0
13971407 while root_node .children :
1398- oldest = min ( root_node .children , key = lambda x : x . left_page )
1408+ oldest = root_node .children_sorted [ 0 ]
13991409 node = oldest
14001410 skipped_nodes = set ()
14011411 while True :
14021412 node .page .new_page_index = new_page_index
1413+ shift = node .page .new_page_index - node .page .page_index
1414+ if shift in shift_counts :
1415+ shift_counts [shift ] += 1
1416+ else :
1417+ shift_counts [shift ] = 1
14031418 new_page_index += 1
14041419 if not node .children : break
1405- next_node = min (node .children , key = lambda x : x .left_page )
1406- skipped_nodes |= set ([n for n in node .children if n != next_node ])
1420+ next_node = node .children_sorted [0 ]
1421+ if len (node .children_sorted ) > 1 :
1422+ skipped_nodes |= set ([n for n in node .children if n != next_node ])
14071423 node = next_node
14081424 root_node .children .remove (oldest )
1425+ root_node .children_sorted .popleft ()
14091426 root_node .children |= skipped_nodes
1427+ if len (skipped_nodes ):
1428+ root_node .presort (False )
1429+
1430+ # Adjust overall shift to minimize page copies
1431+
1432+ shift_adjust = max (shift_counts , key = shift_counts .get )
14101433
14111434 # Order of operations
14121435
14131436 defrag_map = {}
14141437 for page in self .all_pages :
1438+ page .new_page_index = (page .new_page_index - shift_adjust + self .max_pages ) % self .max_pages
14151439 if page .page_index != page .new_page_index :
14161440 defrag_map [page .new_page_index ] = page .page_index
14171441
1442+ # Don't bother if less than 10% of cache is fragmented
1443+
1444+ if len (defrag_map ) <= self .max_pages // 10 :
1445+ return
1446+
14181447 # Shuffle pages
14191448
14201449 cache_tensors = self .cache .all_tensors ()
@@ -1435,12 +1464,11 @@ def __eq__(self, other):
14351464 source = defrag_map [target ]
14361465 del defrag_map [target ]
14371466
1438- rotation = [ r * self . page_size for r in rotation ]
1467+ rotation = torch . tensor ( rotation , dtype = torch . int )
14391468 for cache , buffer in zip (cache_tensors , defrag_buffers ):
1440- buffer [:, :, :, :].copy_ (cache [:, rotation [0 ] : rotation [0 ] + self .page_size , :, :])
1441- for a , b in pairwise (rotation ):
1442- cache [:, a : a + self .page_size , :, :].copy_ (cache [:, b : b + self .page_size , :, :])
1443- cache [:, rotation [- 1 ] : rotation [- 1 ] + self .page_size , :, :].copy_ (buffer [:, :, :, :])
1469+ rotation = rotation .to (cache .device )
1470+ cache = cache .view (cache .shape [1 ] // self .page_size , - 1 )
1471+ ext_c .cache_rotate (cache , rotation , buffer )
14441472
14451473 # Update page table
14461474
0 commit comments