@@ -77,43 +77,40 @@ def input_fn():
         mp_policy=mp_policy,
         compile=job_config.training.compile,
     ) as autop:
-        # currently errors due to missing sharding prop rules
-        torch.distributed.breakpoint()
-
-        # autop.add_parameter_memory_constraint(low=None, high=None)
-
-        # possible_input_shardings = {
-        #     # maps relative to mesh dim names used in torchtitan
-        #     "dp_replicate": Shard(0),
-        #     "dp_shard": Shard(0),
-        #     "tp": Replicate(),
-        # }
-        # # only used if loss parallel is enabled
-        # possible_output_shardings = {
-        #     # maps relative to mesh dim names used in torchtitan
-        #     "dp_shard": Shard(0),
-        #     "tp": Shard(2),
-        # }
-        # assert all(
-        #     name in possible_input_shardings for name in world_mesh.mesh_dim_names
-        # ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
-        # x_sharding = tuple(
-        #     possible_input_shardings[name] for name in world_mesh.mesh_dim_names
-        # )
-        # out_sharding = x_sharding
-        # if parallel_dims.loss_parallel_enabled:
-        #     out_sharding = tuple(
-        #         possible_output_shardings[name]
-        #         for name in world_mesh.mesh_dim_names
-        #         if name != "dp_replicate"
-        #     )
-        # autop.add_input_constraints([x_sharding])
-        # autop.add_output_constraints([out_sharding])
-        # t0 = time.time()
-        # sharding_placement = autop.optimize_placement()
-        # t1 = time.time()
-        # logger.info(f"AutoParallel took {t1 - t0} seconds")
-        # parallel_mod = autop.apply_placement(sharding_placement)
+        autop.add_parameter_memory_constraint(low=None, high=None)
+
+        possible_input_shardings = {
+            # maps relative to mesh dim names used in torchtitan
+            "dp_replicate": Shard(0),
+            "dp_shard": Shard(0),
+            "tp": Replicate(),
+        }
+        # only used if loss parallel is enabled
+        possible_output_shardings = {
+            # maps relative to mesh dim names used in torchtitan
+            "dp_shard": Shard(0),
+            "tp": Shard(2),
+        }
+        assert all(
+            name in possible_input_shardings for name in world_mesh.mesh_dim_names
+        ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
+        x_sharding = tuple(
+            possible_input_shardings[name] for name in world_mesh.mesh_dim_names
+        )
+        out_sharding = x_sharding
+        if parallel_dims.loss_parallel_enabled:
+            out_sharding = tuple(
+                possible_output_shardings[name]
+                for name in world_mesh.mesh_dim_names
+                if name != "dp_replicate"
+            )
+        autop.add_input_constraints([x_sharding])
+        autop.add_output_constraints([out_sharding])
+        t0 = time.time()
+        sharding_placement = autop.optimize_placement()
+        t1 = time.time()
+        logger.info(f"AutoParallel took {t1 - t0} seconds")
+        parallel_mod = autop.apply_placement(sharding_placement)
 
     if parallel_dims.loss_parallel_enabled:
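
Note: the constraints passed to `add_input_constraints` / `add_output_constraints` are plain tuples of DTensor placements, one placement per mesh dimension. Below is a minimal sketch of the lookup the new code performs, runnable without a process group; the 2-D mesh dim names are a made-up example, whereas the diff takes them from `world_mesh.mesh_dim_names`:

```python
# Minimal, process-group-free sketch of the placement lookups above.
# Shard/Replicate are the standard DTensor placement types.
from torch.distributed.tensor import Replicate, Shard

possible_input_shardings = {
    "dp_replicate": Shard(0),  # split the batch dim across replica groups
    "dp_shard": Shard(0),      # split the batch dim across FSDP shard groups
    "tp": Replicate(),         # replicate inputs across tensor-parallel ranks
}
possible_output_shardings = {
    "dp_shard": Shard(0),      # loss-parallel outputs keep the batch split
    "tp": Shard(2),            # ...and shard the logits on the vocab dim
}

mesh_dim_names = ("dp_shard", "tp")  # hypothetical world mesh dims
x_sharding = tuple(possible_input_shardings[n] for n in mesh_dim_names)
out_sharding = tuple(possible_output_shardings[n] for n in mesh_dim_names)
print(x_sharding)    # (Shard(dim=0), Replicate())
print(out_sharding)  # (Shard(dim=0), Shard(dim=2))
```

With loss parallel enabled, constraining the `tp` dim of the output to `Shard(2)` keeps the `[batch, seq, vocab]` logits sharded along the vocab dimension, so the loss can be computed without first gathering the full logits on every rank.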