Skip to content

Commit 43558ed

Browse files
committed
.
Signed-off-by: LeiZhang <isleizhang@outlook.com>
1 parent 2614e22 commit 43558ed

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

examples/omnitalker/talker_ostquant.py

Lines changed: 14 additions & 14 deletions
Original file line number · Diff line number · Diff line change
@@ -421,15 +421,15 @@ def register_smooth_transform(module, transform):
421421
register_smooth_transform(o_proj, S2_transform_inv)
422422

423423
for up_proj, down_proj in zip(up_projs, down_projs):
424-
register_offload_module(up_proj, "S4_transform", S4_transform)
425-
register_offload_module(down_proj, "S4_transform_inv", S4_transform_inv)
426424
S4_transform = SmoothTransform(up_proj.out_features, is_out=True).to(
427425
torch.cuda.current_device()
428426
)
429427
S4_transform_inv = SmoothTransform(
430428
down_proj.in_features, is_out=False, inverse=True
431429
).to(torch.cuda.current_device())
432430
S4_transform_inv.scale = S4_transform.scale
431+
register_offload_module(up_proj, "S4_transform", S4_transform)
432+
register_offload_module(down_proj, "S4_transform_inv", S4_transform_inv)
433433
with align_module_device(up_proj), align_module_device(down_proj):
434434
register_smooth_transform(up_proj, S4_transform)
435435
register_smooth_transform(down_proj, S4_transform_inv)
@@ -496,18 +496,18 @@ def register_smooth_transform(module, transform):
496496
S1_transform_inv = SmoothTransform(
497497
mlp_norm.weight.shape[0], is_out=True, inverse=True
498498
).to(torch.cuda.current_device())
499-
S1_transform_inv.scale = S1_transform.scale
500-
register_offload_module(mlp_norm, "S1_transform_inv", S1_transform_inv)
501-
register_offload_module(mlp_norm, "S1_transform", S1_transform)
502-
with (
503-
torch.no_grad(),
504-
align_module_device(mlp_norm),
505-
):
506-
register_smooth_transform(mlp_norm, S1_transform_inv)
507-
with torch.no_grad():
508-
for module in gates + ups:
509-
with align_module_device(module):
510-
register_smooth_transform(module, S1_transform)
499+
# S1_transform_inv.scale = S1_transform.scale
500+
# register_offload_module(mlp_norm, "S1_transform_inv", S1_transform_inv)
501+
# register_offload_module(mlp_norm, "S1_transform", S1_transform)
502+
# with (
503+
# torch.no_grad(),
504+
# align_module_device(mlp_norm),
505+
# ):
506+
# register_smooth_transform(mlp_norm, S1_transform_inv)
507+
# with torch.no_grad():
508+
# for module in gates + ups:
509+
# with align_module_device(module):
510+
# register_smooth_transform(module, S1_transform)
511511

512512
return state, recipe_, model
513513

0 commit comments

Comments (0)