11import torch
22import torch .distributed as dist
33from cache_dit .envs import ENV
4- from cache_dit .logger import init_logger , logging_rank_0
4+ from cache_dit .logger import init_logger
55
66logger = init_logger (__name__ )
77
88
def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
    """Return whether epilogue and prologue fusion should be enabled.

    Fusion can be requested in two ways: through the
    ``epilogue_prologue_fusion`` keyword argument (treated as ``False`` when
    absent) or through ``ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION``. When the
    environment flag is truthy a message is emitted via ``logging_rank_0``
    and the flag takes effect regardless of the keyword argument.

    Args:
        **kwargs: Arbitrary keyword arguments; only the
            ``epilogue_prologue_fusion`` key is read.

    Returns:
        ``True`` if either the environment flag or the keyword argument is
        truthy, otherwise ``False``.
    """
    forced_by_env = ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION

    if forced_by_env:
        logging_rank_0(
            logger,
            "CACHE_DIT_EPILOGUE_PROLOGUE_FUSION is set to 1. \n "
            "Force enable epilogue and prologue fusion.",
        )

    # Coerce explicitly so the declared ``-> bool`` contract holds even when
    # the caller passes a non-bool value (e.g. ``None`` or a string).
    return bool(forced_by_env or kwargs.get("epilogue_prologue_fusion", False))
20-
21-
def enable_compile_compute_comm_overlap():
    """Switch on compute-communication overlap for compilation.

    Sets ``ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP`` to ``True``
    and logs the change at INFO level.
    """
    ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP = True
    logger.info("Enabled compile compute-communication overlap manually.")
25-
26-
def disable_compile_compute_comm_overlap():
    """Switch off compute-communication overlap for compilation.

    Sets ``ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP`` to ``False``
    and logs the change at INFO level.
    """
    ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP = False
    logger.info("Disabled compile compute-communication overlap manually.")
30-
31-
def is_compile_compute_comm_overlap_enabled() -> bool:
    """Report the current compute-communication overlap setting.

    Returns:
        The present value of
        ``ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP``, returned as-is.
    """
    overlap_flag = ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
    return overlap_flag
34-
35-
369def set_compile_configs (
3710 descent_tuning : bool = False ,
3811 cuda_graphs : bool = False ,
@@ -56,7 +29,7 @@ def set_compile_configs(
5629 if dist .is_initialized ():
5730 # Enable compute comm overlap
5831 torch ._inductor .config .reorder_for_compute_comm_overlap = (
59- compute_comm_overlap and is_compile_compute_comm_overlap_enabled ()
32+ compute_comm_overlap and ENV . CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
6033 )
6134 # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
6235 if torch ._inductor .config .reorder_for_compute_comm_overlap :
@@ -73,8 +46,7 @@ def set_compile_configs(
7346 return
7447
7548 if ENV .CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG :
76- logging_rank_0 (
77- logger ,
49+ logger .info (
7850 "CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG is set to 1. \n "
7951 "Force disable custom compile config." ,
8052 )
@@ -95,7 +67,7 @@ def set_compile_configs(
9567 torch ._inductor .config .epilogue_fusion = False
9668
9769 # Enable epilogue and prologue fusion
98- if epilogue_prologue_fusion_enabled ( ** kwargs ):
70+ if ENV . CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or kwargs . get ( "epilogue_prologue_fusion" , False ):
9971 torch ._inductor .config .epilogue_fusion = True
10072 torch ._inductor .config .prologue_fusion = True
10173 torch ._inductor .config .epilogue_fusion_first = True
0 commit comments