@@ -421,15 +421,15 @@ def register_smooth_transform(module, transform):
421421 register_smooth_transform (o_proj , S2_transform_inv )
422422
423423 for up_proj , down_proj in zip (up_projs , down_projs ):
424- register_offload_module (up_proj , "S4_transform" , S4_transform )
425- register_offload_module (down_proj , "S4_transform_inv" , S4_transform_inv )
426424 S4_transform = SmoothTransform (up_proj .out_features , is_out = True ).to (
427425 torch .cuda .current_device ()
428426 )
429427 S4_transform_inv = SmoothTransform (
430428 down_proj .in_features , is_out = False , inverse = True
431429 ).to (torch .cuda .current_device ())
432430 S4_transform_inv .scale = S4_transform .scale
431+ register_offload_module (up_proj , "S4_transform" , S4_transform )
432+ register_offload_module (down_proj , "S4_transform_inv" , S4_transform_inv )
433433 with align_module_device (up_proj ), align_module_device (down_proj ):
434434 register_smooth_transform (up_proj , S4_transform )
435435 register_smooth_transform (down_proj , S4_transform_inv )
@@ -496,18 +496,18 @@ def register_smooth_transform(module, transform):
496496 S1_transform_inv = SmoothTransform (
497497 mlp_norm .weight .shape [0 ], is_out = True , inverse = True
498498 ).to (torch .cuda .current_device ())
499- S1_transform_inv .scale = S1_transform .scale
500- register_offload_module (mlp_norm , "S1_transform_inv" , S1_transform_inv )
501- register_offload_module (mlp_norm , "S1_transform" , S1_transform )
502- with (
503- torch .no_grad (),
504- align_module_device (mlp_norm ),
505- ):
506- register_smooth_transform (mlp_norm , S1_transform_inv )
507- with torch .no_grad ():
508- for module in gates + ups :
509- with align_module_device (module ):
510- register_smooth_transform (module , S1_transform )
499+ # S1_transform_inv.scale = S1_transform.scale
500+ # register_offload_module(mlp_norm, "S1_transform_inv", S1_transform_inv)
501+ # register_offload_module(mlp_norm, "S1_transform", S1_transform)
502+ # with (
503+ # torch.no_grad(),
504+ # align_module_device(mlp_norm),
505+ # ):
506+ # register_smooth_transform(mlp_norm, S1_transform_inv)
507+ # with torch.no_grad():
508+ # for module in gates + ups:
509+ # with align_module_device(module):
510+ # register_smooth_transform(module, S1_transform)
511511
512512 return state , recipe_ , model
513513
0 commit comments