@@ -193,7 +193,7 @@ def apply_block_swap_to_dit(
193193 runner ._blockswap_active = True
194194
195195 # Store configuration for debugging and cleanup
196- runner ._block_swap_config = {
196+ model ._block_swap_config = {
197197 "blocks_swapped" : blocks_to_swap ,
198198 "swap_io_components" : swap_io_components ,
199199 "total_blocks" : total_blocks ,
@@ -650,11 +650,11 @@ def _protect_model_from_move(
650650
651651 Wraps model.to() method to prevent other code from accidentally moving
652652 the entire model to GPU, which would defeat BlockSwap's memory savings.
653- Allows movement only when explicitly bypassed via runner flag.
653+ Allows movement only when explicitly bypassed via model flag.
654654
655655 Args:
656656 model: DiT model to protect
657- runner: VideoDiffusionInfer instance (stored as weak reference )
657+ runner: VideoDiffusionInfer instance (for active status check )
658658 debug: Debug instance for logging (required)
659659 """
660660 if not hasattr (model , '_original_to' ):
@@ -665,34 +665,46 @@ def _protect_model_from_move(
665665 # Define the protected method without closures
666666 def protected_model_to (self , device , * args , ** kwargs ):
667667 # Check if protection is temporarily bypassed for offloading
668+ # Flag is stored on model itself (not runner) to survive runner recreation
669+ if getattr (self , "_blockswap_bypass_protection" , False ):
670+ # Protection bypassed, allow movement
671+ if hasattr (self , '_original_to' ):
672+ return self ._original_to (device , * args , ** kwargs )
673+
674+ # Get configured offload device directly from model
675+ blockswap_offload_device = "cpu" # default
676+ if hasattr (self , "_block_swap_config" ):
677+ blockswap_offload_device = self ._block_swap_config .get ("offload_device" , "cpu" )
678+
679+ # Check if BlockSwap is currently active via runner weak reference
668680 runner_ref = getattr (self , '_blockswap_runner_ref' , None )
681+ blockswap_is_active = False
669682 if runner_ref :
670683 runner_obj = runner_ref ()
671- if runner_obj and getattr (runner_obj , "_blockswap_bypass_protection" , False ):
672- # Protection bypassed, allow movement
673- if hasattr (self , '_original_to' ):
674- return self ._original_to (device , * args , ** kwargs )
684+ if runner_obj and hasattr (runner_obj , "_blockswap_active" ):
685+ blockswap_is_active = runner_obj ._blockswap_active
675686
676- # Check blockswap status using weak reference
677- # Get configured offload device from runner
678- blockswap_offload_device = "cpu" # default
679- if runner_ref :
680- runner_obj = runner_ref ()
681- if runner_obj and hasattr (runner_obj , "_block_swap_config" ):
682- blockswap_offload_device = runner_obj ._block_swap_config .get ("offload_device" , "cpu" )
687+ # Block attempts to move model away from configured offload device when active
688+ if blockswap_is_active and str (device ) != str (blockswap_offload_device ):
689+ # Get debug instance from runner if available
690+ debug_instance = None
691+ if runner_ref :
692+ runner_obj = runner_ref ()
693+ if runner_obj and hasattr (runner_obj , 'debug' ):
694+ debug_instance = runner_obj .debug
683695
684- # Block attempts to move model away from configured offload device
685- if str ( device ) != str ( blockswap_offload_device ):
686- if runner_obj and hasattr ( runner_obj , "_blockswap_active" ) and runner_obj . _blockswap_active :
687- debug . log ( f"Blocked attempt to move blockswapped model from { blockswap_offload_device } to { device } " ,
688- level = "WARNING" , category = "blockswap" , force = True )
689- return self
696+ if debug_instance :
697+ debug_instance . log (
698+ f"Blocked attempt to move BlockSwap model from { blockswap_offload_device } to { device } " ,
699+ level = "WARNING" , category = "blockswap" , force = True
700+ )
701+ return self
690702
691- # Use original method stored as attribute
703+ # Allow movement (either bypass is enabled or target is offload device)
692704 if hasattr (self , '_original_to' ):
693705 return self ._original_to (device , * args , ** kwargs )
694706 else :
695- # This shouldn't happen, but fallback to super().to()
707+ # Fallback - shouldn't happen
696708 return super (type (self ), self ).to (device , * args , ** kwargs )
697709
698710 # Bind as a method to the model instance
@@ -712,7 +724,13 @@ def set_blockswap_bypass(runner, bypass: bool, debug):
712724 if not hasattr (runner , "_blockswap_active" ) or not runner ._blockswap_active :
713725 return
714726
715- runner ._blockswap_bypass_protection = bypass
727+ # Get the actual model (handle FP8CompatibleDiT wrapper)
728+ model = runner .dit
729+ if hasattr (model , "dit_model" ):
730+ model = model .dit_model
731+
732+ # Store on model so it survives runner recreation during caching
733+ model ._blockswap_bypass_protection = bypass
716734
717735 if bypass :
718736 debug .log ("BlockSwap protection disabled to allow model DiT offloading" , category = "success" )
@@ -741,11 +759,16 @@ def cleanup_blockswap(runner, keep_state_for_cache=False):
741759
742760 debug = runner .debug
743761
744- # Check if there's any BlockSwap state to clean up
762+ # Get the actual model (handle FP8CompatibleDiT wrapper)
763+ model = runner .dit
764+ if hasattr (model , "dit_model" ):
765+ model = model .dit_model
766+
767+ # Check if there's any BlockSwap state to clean up (check both runner and model)
745768 has_blockswap_state = (
746769 hasattr (runner , "_blockswap_active" ) or
747- hasattr (runner , "_block_swap_config" ) or
748- hasattr (runner , "_blockswap_bypass_protection" )
770+ hasattr (model , "_block_swap_config" ) or
771+ hasattr (model , "_blockswap_bypass_protection" )
749772 )
750773
751774 if not has_blockswap_state :
@@ -757,7 +780,7 @@ def cleanup_blockswap(runner, keep_state_for_cache=False):
757780 # Minimal cleanup for caching - just mark as inactive and allow offloading
758781 # Everything else stays intact for fast reactivation
759782 if hasattr (runner , "_blockswap_active" ) and runner ._blockswap_active :
760- if not getattr (runner , "_blockswap_bypass_protection" , False ):
783+ if not getattr (model , "_blockswap_bypass_protection" , False ):
761784 set_blockswap_bypass (runner = runner , bypass = True , debug = debug )
762785 runner ._blockswap_active = False
763786 debug .log ("BlockSwap deactivated for caching (configuration preserved)" , category = "success" )
@@ -829,7 +852,7 @@ def cleanup_blockswap(runner, keep_state_for_cache=False):
829852
830853 # 5. Clean up BlockSwap-specific attributes
831854 for attr in ['_blockswap_runner_ref' , 'blocks_to_swap' , 'main_device' ,
832- 'offload_device' , '_blockswap_configured' ]:
855+ 'offload_device' ]:
833856 if hasattr (model , attr ):
834857 delattr (model , attr )
835858
0 commit comments