Skip to content

Commit 21936c6

Browse files
authored
chore: refactor kernels module (#618)
* chore: refactor kernels module
* chore: refactor kernels module
* chore: refactor kernels module
1 parent 226ba52 commit 21936c6

File tree

7 files changed

+7
-36
lines changed

7 files changed

+7
-36
lines changed

src/cache_dit/compile/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1 @@
11
from cache_dit.compile.utils import set_compile_configs
2-
from cache_dit.compile.utils import enable_compile_compute_comm_overlap
3-
from cache_dit.compile.utils import disable_compile_compute_comm_overlap
4-
from cache_dit.compile.utils import is_compile_compute_comm_overlap_enabled

src/cache_dit/compile/utils.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,11 @@
11
import torch
22
import torch.distributed as dist
33
from cache_dit.envs import ENV
4-
from cache_dit.logger import init_logger, logging_rank_0
4+
from cache_dit.logger import init_logger
55

66
logger = init_logger(__name__)
77

88

9-
def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
10-
mode = kwargs.get("epilogue_prologue_fusion", False)
11-
12-
if ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION:
13-
logging_rank_0(
14-
logger,
15-
"CACHE_DIT_EPILOGUE_PROLOGUE_FUSION is set to 1. \n"
16-
"Force enable epilogue and prologue fusion.",
17-
)
18-
19-
return ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or mode
20-
21-
22-
def enable_compile_compute_comm_overlap():
23-
ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP = True
24-
logger.info("Enabled compile compute-communication overlap manually.")
25-
26-
27-
def disable_compile_compute_comm_overlap():
28-
ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP = False
29-
logger.info("Disabled compile compute-communication overlap manually.")
30-
31-
32-
def is_compile_compute_comm_overlap_enabled() -> bool:
33-
return ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
34-
35-
369
def set_compile_configs(
3710
descent_tuning: bool = False,
3811
cuda_graphs: bool = False,
@@ -56,7 +29,7 @@ def set_compile_configs(
5629
if dist.is_initialized():
5730
# Enable compute comm overlap
5831
torch._inductor.config.reorder_for_compute_comm_overlap = (
59-
compute_comm_overlap and is_compile_compute_comm_overlap_enabled()
32+
compute_comm_overlap and ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
6033
)
6134
# Reference bandwidths: L20 over PCIe is 64 GB/s; A100/A800 over NVLink is 300 GB/s.
6235
if torch._inductor.config.reorder_for_compute_comm_overlap:
@@ -73,8 +46,7 @@ def set_compile_configs(
7346
return
7447

7548
if ENV.CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG:
76-
logging_rank_0(
77-
logger,
49+
logger.info(
7850
"CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG is set to 1. \n"
7951
"Force disable custom compile config.",
8052
)
@@ -95,7 +67,7 @@ def set_compile_configs(
9567
torch._inductor.config.epilogue_fusion = False
9668

9769
# Enable epilogue and prologue fusion
98-
if epilogue_prologue_fusion_enabled(**kwargs):
70+
if ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or kwargs.get("epilogue_prologue_fusion", False):
9971
torch._inductor.config.epilogue_fusion = True
10072
torch._inductor.config.prologue_fusion = True
10173
torch._inductor.config.epilogue_fusion_first = True

src/cache_dit/kernels/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from cache_dit.kernels.triton_per_token_quant_8bit import per_token_quant_fp8, per_token_dequant_fp8
1+
from .triton import per_token_quant_fp8, per_token_dequant_fp8
File renamed without changes.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .per_token_quant_8bit import per_token_quant_fp8
2+
from .per_token_quant_8bit import per_token_dequant_fp8
File renamed without changes.

src/cache_dit/kernels/triton/taylorseer.py

Whitespace-only changes.

0 commit comments

Comments
 (0)