@@ -74,8 +74,6 @@ def default_activation_checkpoint_policy() -> _PolicyFn:
7474 torch ._higher_order_ops .inductor_compiled_code
7575 ] = CheckpointPolicy .MUST_SAVE
7676
77- compute_intensive_ops [torch .ops .aten .max .default ] = CheckpointPolicy .MUST_SAVE
78-
7977 if hasattr (torch .ops , "torch_attn" ) and hasattr (
8078 torch .ops .torch_attn , "_varlen_attn"
8179 ):
@@ -88,23 +86,6 @@ def default_activation_checkpoint_policy() -> _PolicyFn:
8886 torch .ops ._c10d_functional .all_to_all_single .default : CheckpointPolicy .MUST_SAVE ,
8987 }
9088
91- # DeepEP ops for MoE expert parallelism
92- # Try to import deepep module to register custom ops, then check if they exist
93- try :
94- import torchtitan .distributed .deepep # noqa: F401 - registers torch.ops.deepep
95-
96- if hasattr (torch .ops , "deepep" ):
97- if hasattr (torch .ops .deepep , "dispatch" ):
98- communication_intensive_ops [
99- torch .ops .deepep .dispatch .default
100- ] = CheckpointPolicy .MUST_SAVE
101- if hasattr (torch .ops .deepep , "combine" ):
102- communication_intensive_ops [
103- torch .ops .deepep .combine .default
104- ] = CheckpointPolicy .MUST_SAVE
105- except ImportError :
106- pass # DeepEP not available
107-
10889 policy_fn = partial (
10990 _sac_policy_fn ,
11091 compute_intensive_ops = compute_intensive_ops ,
0 commit comments