@@ -13,18 +13,42 @@ namespace FlexFlow {
1313std::vector<Substitution>
1414 get_substitution_set (MachineSpecification const &resources) {
1515 std::vector<Substitution> substitutions;
16- for (nonnegative_int num_dims :
16+ for (nonnegative_int dim :
1717 nonnegative_range (1_n, nonnegative_int{MAX_TENSOR_DIM})) {
1818 for (nonnegative_int degree = 1_n; degree <= get_num_gpus (resources);
1919 degree *= 2_n) {
2020 substitutions.push_back (
21- create_replicate_linear_combine (num_dims , degree, true ));
21+ create_replicate_linear_combine (dim , degree, true ));
2222 substitutions.push_back (
23- create_replicate_linear_combine (num_dims , degree, false ));
23+ create_replicate_linear_combine (dim , degree, false ));
2424 substitutions.push_back (
25- create_partition_linear_combine (num_dims , degree, true ));
25+ create_partition_linear_combine (dim , degree, true ));
2626 substitutions.push_back (
27- create_partition_linear_combine (num_dims, degree, false ));
27+ create_partition_linear_combine (dim, degree, false ));
28+ substitutions.push_back (
29+ create_partition_relu_combine (ff_dim_t {dim}, degree));
30+ substitutions.push_back (
31+ create_partition_add_combine (ff_dim_t {dim}, degree));
32+ substitutions.push_back (create_partition_attention_combine (dim, degree));
33+ substitutions.push_back (create_replicate_attention_reduce (dim, degree));
34+ }
35+ }
36+ for (nonnegative_int degree = 1_n; degree <= get_num_gpus (resources);
37+ degree *= 2_n) {
38+ substitutions.push_back (create_partition_conv2d_combine (4_n, degree));
39+ }
40+
41+ for (nonnegative_int partition_dim :
42+ nonnegative_range (1_n, nonnegative_int{MAX_TENSOR_DIM})) {
43+ for (nonnegative_int softmax_dim :
44+ nonnegative_range (1_n, nonnegative_int{MAX_TENSOR_DIM})) {
45+ for (nonnegative_int degree = 1_n; degree <= get_num_gpus (resources);
46+ degree *= 2_n) {
47+ if (partition_dim != softmax_dim) {
48+ substitutions.push_back (create_partition_softmax_combine (
49+ ff_dim_t {partition_dim}, ff_dim_t {softmax_dim}, degree));
50+ }
51+ }
2852 }
2953 }
3054 substitutions.push_back (create_fuse_linear_activation (Activation::RELU));
@@ -173,7 +197,7 @@ Substitution create_partition_linear_combine(nonnegative_int num_dims,
173197 set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
174198 OperatorAttributeValue{degree}),
175199 set_attr_to_constant (OperatorAttributeKey::PARALLEL_DIM,
176- OperatorAttributeValue{ff_dim_t {1_n }}),
200+ OperatorAttributeValue{ff_dim_t {0_n }}),
177201 }};
178202 OutputGraphExprValue o_partition_input_output =
179203 get_only (b.add_output_graph_node (partition_input_expr, {o_input}, 1_n));
@@ -265,13 +289,13 @@ Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
265289 set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
266290 OperatorAttributeValue{degree}),
267291 set_attr_to_constant (OperatorAttributeKey::PARALLEL_DIM,
268- OperatorAttributeValue{ff_dim_t {1_n }}),
292+ OperatorAttributeValue{ff_dim_t {0_n }}),
269293 }};
270294
271295 OutputGraphExprValue o_partition_input_output =
272296 get_only (b.add_output_graph_node (partition_input_expr, {o_input}, 1_n));
273297
274- /* OutputOperatorAttrsAssignment replicate_weights_expr =
298+ OutputOperatorAttrsAssignment replicate_weights_expr =
275299 OutputOperatorAttrsAssignment{
276300 std::nullopt ,
277301 {
@@ -283,10 +307,7 @@ Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
283307 b.add_output_graph_node (replicate_weights_expr, {o_weight}, 1_n));
284308
285309 std::vector<OutputGraphExprValue> o_conv2d_inputs = {
286- o_partition_input_output, o_replicate_weights_output};*/
287-
288- std::vector<OutputGraphExprValue> o_conv2d_inputs = {o_partition_input_output,
289- o_weight};
310+ o_partition_input_output, o_replicate_weights_output};
290311
291312 OutputOperatorAttrsAssignment conv2d_expr = OutputOperatorAttrsAssignment{
292313 b.pattern_node_named (" conv2d" ),
@@ -321,15 +342,16 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,
321342
322343 SubstitutionBuilder b;
323344
324- auto [p_input, o_input] = b.add_input (tensor_attribute_pattern_match_all ());
325- auto [p_query_weight, o_query_weight] =
345+ auto [p_query_input, o_query_input] =
346+ b.add_input (tensor_attribute_pattern_match_all ());
347+ auto [p_key_input, o_key_input] =
326348 b.add_input (tensor_attribute_pattern_match_all ());
327- auto [p_key_weight, o_key_weight ] =
349+ auto [p_value_input, o_value_input ] =
328350 b.add_input (tensor_attribute_pattern_match_all ());
329- auto [p_value_weight, o_value_weight ] =
351+ auto [p_weights, o_weights ] =
330352 b.add_input (tensor_attribute_pattern_match_all ());
331353 std::vector<PatternValue> p_inputs = {
332- p_input, p_input, p_input, p_query_weight, p_key_weight, p_value_weight };
354+ p_query_input, p_key_input, p_value_input, p_weights };
333355
334356 OperatorAttributePattern attention_pattern = OperatorAttributePattern{{
335357 op_type_equals_constraint (OperatorType::MULTIHEAD_ATTENTION),
@@ -351,19 +373,35 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,
351373 set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
352374 OperatorAttributeValue{degree}),
353375 set_attr_to_constant (OperatorAttributeKey::PARALLEL_DIM,
354- OperatorAttributeValue{ff_dim_t {1_n }}),
376+ OperatorAttributeValue{ff_dim_t {0_n }}),
355377 }};
356378
357- OutputGraphExprValue o_partition_input_output =
358- get_only (b.add_output_graph_node (partition_input_expr, {o_input}, 1_n));
379+ OutputGraphExprValue o_partition_query_input_output = get_only (
380+ b.add_output_graph_node (partition_input_expr, {o_query_input}, 1_n));
381+
382+ OutputGraphExprValue o_partition_key_input_output = get_only (
383+ b.add_output_graph_node (partition_input_expr, {o_key_input}, 1_n));
384+
385+ OutputGraphExprValue o_partition_value_input_output = get_only (
386+ b.add_output_graph_node (partition_input_expr, {o_value_input}, 1_n));
387+
388+ OutputOperatorAttrsAssignment replicate_weight_expr =
389+ OutputOperatorAttrsAssignment{
390+ std::nullopt ,
391+ {
392+ set_op_type_attr (OperatorType::REPLICATE),
393+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
394+ OperatorAttributeValue{degree}),
395+ }};
396+
397+ OutputGraphExprValue o_replicate_weight_output = get_only (
398+ b.add_output_graph_node (replicate_weight_expr, {o_weights}, 1_n));
359399
360400 std::vector<OutputGraphExprValue> o_attention_inputs = {
361- o_partition_input_output,
362- o_partition_input_output,
363- o_partition_input_output,
364- o_query_weight,
365- o_key_weight,
366- o_value_weight};
401+ o_partition_query_input_output,
402+ o_partition_key_input_output,
403+ o_partition_value_input_output,
404+ o_replicate_weight_output};
367405
368406 OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{
369407 b.pattern_node_named (" attention" ),
@@ -394,17 +432,19 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,
394432
395433Substitution create_replicate_attention_reduce (nonnegative_int num_heads,
396434 nonnegative_int degree) {
435+
397436 SubstitutionBuilder b;
398437
399- auto [p_input, o_input] = b.add_input (tensor_attribute_pattern_match_all ());
400- auto [p_query_weight, o_query_weight] =
438+ auto [p_query_input, o_query_input] =
401439 b.add_input (tensor_attribute_pattern_match_all ());
402- auto [p_key_weight, o_key_weight ] =
440+ auto [p_key_input, o_key_input ] =
403441 b.add_input (tensor_attribute_pattern_match_all ());
404- auto [p_value_weight, o_value_weight] =
442+ auto [p_value_input, o_value_input] =
443+ b.add_input (tensor_attribute_pattern_match_all ());
444+ auto [p_weights, o_weights] =
405445 b.add_input (tensor_attribute_pattern_match_all ());
406446 std::vector<PatternValue> p_inputs = {
407- p_input, p_input, p_input, p_query_weight, p_key_weight, p_value_weight };
447+ p_query_input, p_key_input, p_value_input, p_weights };
408448
409449 OperatorAttributePattern attention_pattern = OperatorAttributePattern{{
410450 op_type_equals_constraint (OperatorType::MULTIHEAD_ATTENTION),
@@ -421,20 +461,40 @@ Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
421461 OutputOperatorAttrsAssignment replicate_input_expr =
422462 OutputOperatorAttrsAssignment{
423463 std::nullopt ,
424- {set_op_type_attr (OperatorType::REPLICATE),
425- set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
426- OperatorAttributeValue{degree})}};
464+ {
465+ set_op_type_attr (OperatorType::REPLICATE),
466+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
467+ OperatorAttributeValue{degree}),
468+ }};
427469
428- OutputGraphExprValue o_replicate_input_output =
429- get_only (b.add_output_graph_node (replicate_input_expr, {o_input}, 1_n));
470+ OutputGraphExprValue o_replicate_query_input_output = get_only (
471+ b.add_output_graph_node (replicate_input_expr, {o_query_input}, 1_n));
472+
473+ OutputGraphExprValue o_replicate_key_input_output = get_only (
474+ b.add_output_graph_node (replicate_input_expr, {o_key_input}, 1_n));
475+
476+ OutputGraphExprValue o_replicate_value_input_output = get_only (
477+ b.add_output_graph_node (replicate_input_expr, {o_value_input}, 1_n));
478+
479+ OutputOperatorAttrsAssignment partition_weight_expr =
480+ OutputOperatorAttrsAssignment{
481+ std::nullopt ,
482+ {
483+ set_op_type_attr (OperatorType::REPARTITION),
484+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
485+ OperatorAttributeValue{degree}),
486+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DIM,
487+ OperatorAttributeValue{ff_dim_t {1_n}}),
488+ }};
489+
490+ OutputGraphExprValue o_partition_weight_output = get_only (
491+ b.add_output_graph_node (partition_weight_expr, {o_weights}, 1_n));
430492
431493 std::vector<OutputGraphExprValue> o_attention_inputs = {
432- o_replicate_input_output,
433- o_replicate_input_output,
434- o_replicate_input_output,
435- o_query_weight,
436- o_key_weight,
437- o_value_weight};
494+ o_replicate_query_input_output,
495+ o_replicate_key_input_output,
496+ o_replicate_value_input_output,
497+ o_partition_weight_output};
438498
439499 OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{
440500 b.pattern_node_named (" attention" ),
@@ -451,14 +511,83 @@ Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
451511 OperatorAttributeValue{degree}),
452512 },
453513 };
454- OutputGraphExprValue o_reduce_output = get_only (
455- b.add_output_graph_node (reduce_expr, {o_attention_output}, 1_n));
514+ OutputGraphExprValue o_reduce_output =
515+ get_only ( b.add_output_graph_node (reduce_expr, {o_attention_output}, 1_n));
456516
457517 b.equate_outputs (p_attention_output, o_reduce_output);
458518
459519 return b.get_substitution ();
460520}
461521
// Builds a substitution that rewrites a matched SOFTMAX operator into
//   REPARTITION(input) -> softmax -> COMBINE
// where both the REPARTITION and the COMBINE node are assigned
// PARALLEL_DEGREE = degree and PARALLEL_DIM = partition_dim.
//
// Parameters:
//   softmax_dim   - value used in the SOFTMAX_DIM constraint of the pattern.
//   partition_dim - dimension assigned to the REPARTITION and COMBINE nodes;
//                   must differ from softmax_dim.
//   degree        - parallel degree assigned to REPARTITION and COMBINE, and
//                   the divisor used in the OUT_CHANNELS constraint.
//
// Throws (via mk_runtime_error) when partition_dim == softmax_dim.
//
// NOTE(review): the call site visible in get_substitution_set passes
// (ff_dim_t{partition_dim}, ff_dim_t{softmax_dim}, degree), i.e. the first two
// arguments are in the opposite order to this parameter list -- confirm
// against the header declaration which order is intended.
522+ Substitution create_partition_softmax_combine (ff_dim_t softmax_dim,
523+ ff_dim_t partition_dim,
524+ nonnegative_int degree) {
// Reject the case where the partitioned dimension is the softmax dimension.
525+ if (partition_dim == softmax_dim) {
526+ throw mk_runtime_error (
527+ fmt::format (" partition dim {} must not be equal to softmax dim {}" ,
528+ partition_dim,
529+ softmax_dim));
530+ }
531+ SubstitutionBuilder b;
532+
// Single graph input, registered with a match-all tensor pattern; p_input is
// the pattern-side handle and o_input the output-graph-side handle.
533+ auto [p_input, o_input] = b.add_input (tensor_attribute_pattern_match_all ());
534+ std::vector<PatternValue> p_inputs = {p_input};
535+
// Operator constraints for the pattern node: op type must be SOFTMAX, plus two
// divisibility constraints.
// NOTE(review): SOFTMAX_DIM is constrained with op_attr_key_divisible_by
// against softmax_dim.value rather than with an equality check -- this looks
// like it may have been meant to require SOFTMAX_DIM == softmax_dim; confirm.
// NOTE(review): confirm that a SOFTMAX operator actually exposes
// OUT_CHANNELS; this constraint appears to mirror the linear/conv2d builders.
536+ OperatorAttributePattern softmax_pattern = OperatorAttributePattern{{
537+ op_type_equals_constraint (OperatorType::SOFTMAX),
538+ op_attr_key_divisible_by (OperatorAttributeKey::OUT_CHANNELS, degree),
539+ op_attr_key_divisible_by (OperatorAttributeKey::SOFTMAX_DIM,
540+ softmax_dim.value ),
541+ }};
542+
// Add the softmax pattern node with one match-all output; the name given here
// is the one looked up later through b.pattern_node_named.
543+ PatternValue p_softmax_output =
544+ get_only (b.add_pattern_node (softmax_pattern,
545+ p_inputs,
546+ {tensor_attribute_pattern_match_all ()},
547+ " softmax" ));
548+
// Output graph, step 1: REPARTITION node applied to the input along
// partition_dim with the requested degree.
549+ OutputOperatorAttrsAssignment partition_input_expr =
550+ OutputOperatorAttrsAssignment{
551+ std::nullopt ,
552+ {
553+ set_op_type_attr (OperatorType::REPARTITION),
554+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
555+ OperatorAttributeValue{degree}),
556+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DIM,
557+ OperatorAttributeValue{partition_dim}),
558+ }};
559+
560+ OutputGraphExprValue o_partition_input_output =
561+ get_only (b.add_output_graph_node (partition_input_expr, {o_input}, 1_n));
562+
563+ std::vector<OutputGraphExprValue> o_softmax_inputs = {
564+ o_partition_input_output};
565+
// Output graph, step 2: softmax node based on the named pattern node, with an
// empty list of attribute assignments (presumably the matched operator's
// attributes are carried over unchanged -- confirm against
// OutputOperatorAttrsAssignment semantics).
566+ OutputOperatorAttrsAssignment softmax_expr = OutputOperatorAttrsAssignment{
567+ b.pattern_node_named (" softmax" ),
568+ {},
569+ };
570+ OutputGraphExprValue o_softmax_output =
571+ get_only (b.add_output_graph_node (softmax_expr, o_softmax_inputs, 1_n));
572+
// Output graph, step 3: COMBINE node over the softmax output, using the same
// degree and dimension as the REPARTITION above.
573+ OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{
574+ std::nullopt ,
575+ {
576+ set_op_type_attr (OperatorType::COMBINE),
577+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
578+ OperatorAttributeValue{degree}),
579+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DIM,
580+ OperatorAttributeValue{partition_dim}),
581+ },
582+ };
583+ OutputGraphExprValue o_combine_output =
584+ get_only (b.add_output_graph_node (combine_expr, {o_softmax_output}, 1_n));
585+
// Pair the pattern's softmax output with the output graph's combine output.
586+ b.equate_outputs (p_softmax_output, o_combine_output);
587+
588+ return b.get_substitution ();
589+ }
590+
462591Substitution create_partition_add_combine (ff_dim_t parallel_dim,
463592 nonnegative_int degree) {
464593 SubstitutionBuilder b;
0 commit comments