
Commit 4fe85ed

Author: Victor Li (committed)
Updating unity substitution set to work with new PCG interface
1 parent 1394173 commit 4fe85ed

File tree: 6 files changed (+1064, -870 lines)


.envrc

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+source_up_if_exists
+
+use flake

.vimrc

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+" example search path configuration
+set path=lib/runtime/**,lib/**
+
+" set build target
+" let g:target = "pcg"
+
+" set test target
+" let g:test_target = "utils-test"
File renamed without changes.

lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc

Lines changed: 5 additions & 2 deletions
@@ -125,7 +125,11 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
       }};
     case OperatorType::REDUCTION:
       return PCGOperatorAttrs{ReductionAttrs{
-          acc.get<nonnegative_int>(OperatorAttributeKey::PARALLEL_DEGREE),
+          acc.get<nonnegative_int>(OperatorAttributeKey::PARALLEL_DEGREE),
+      }};
+    case OperatorType::SOFTMAX:
+      return PCGOperatorAttrs{SoftmaxAttrs{
+          acc.get<ff_dim_t>(OperatorAttributeKey::AXIS),
       }};
     case OperatorType::BATCHMATMUL:
     case OperatorType::SCALAR_MULTIPLY:
@@ -138,7 +142,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
     case OperatorType::TANH:
     case OperatorType::ELU:
     case OperatorType::FLAT:
-    case OperatorType::SOFTMAX:
     case OperatorType::BATCHNORM:
     case OperatorType::CONCAT:
     case OperatorType::SPLIT:

lib/substitutions/src/substitutions/unity_substitution_set.cc

Lines changed: 173 additions & 44 deletions
@@ -13,18 +13,42 @@ namespace FlexFlow {
 std::vector<Substitution>
     get_substitution_set(MachineSpecification const &resources) {
   std::vector<Substitution> substitutions;
-  for (nonnegative_int num_dims :
+  for (nonnegative_int dim :
        nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) {
     for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources);
          degree *= 2_n) {
       substitutions.push_back(
-          create_replicate_linear_combine(num_dims, degree, true));
+          create_replicate_linear_combine(dim, degree, true));
       substitutions.push_back(
-          create_replicate_linear_combine(num_dims, degree, false));
+          create_replicate_linear_combine(dim, degree, false));
       substitutions.push_back(
-          create_partition_linear_combine(num_dims, degree, true));
+          create_partition_linear_combine(dim, degree, true));
       substitutions.push_back(
-          create_partition_linear_combine(num_dims, degree, false));
+          create_partition_linear_combine(dim, degree, false));
+      substitutions.push_back(
+          create_partition_relu_combine(ff_dim_t{dim}, degree));
+      substitutions.push_back(
+          create_partition_add_combine(ff_dim_t{dim}, degree));
+      substitutions.push_back(create_partition_attention_combine(dim, degree));
+      substitutions.push_back(create_replicate_attention_reduce(dim, degree));
+    }
+  }
+  for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources);
+       degree *= 2_n) {
+    substitutions.push_back(create_partition_conv2d_combine(4_n, degree));
+  }
+
+  for (nonnegative_int partition_dim :
+       nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) {
+    for (nonnegative_int softmax_dim :
+         nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) {
+      for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources);
+           degree *= 2_n) {
+        if (partition_dim != softmax_dim) {
+          substitutions.push_back(create_partition_softmax_combine(
+              ff_dim_t{partition_dim}, ff_dim_t{softmax_dim}, degree));
+        }
+      }
     }
   }
   substitutions.push_back(create_fuse_linear_activation(Activation::RELU));
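A note on the loops above: the parallel degree only doubles from 1 up to the machine's GPU count, so the enumeration stays small even though every dimension is visited. The sketch below mirrors that iteration with plain ints; MAX_TENSOR_DIM = 4 and an 8-GPU machine are assumed example values, not taken from the repository.

// Counts how many (dim, degree) pairs the nested loops above would visit,
// using plain ints; 4 dims and 8 GPUs are assumed example values.
#include <iostream>

int main() {
  int const max_tensor_dim = 4; // assumed, stands in for MAX_TENSOR_DIM
  int const num_gpus = 8;       // assumed machine size

  int pairs = 0;
  for (int dim = 1; dim < max_tensor_dim; ++dim) {
    for (int degree = 1; degree <= num_gpus; degree *= 2) {
      ++pairs; // each pair yields several substitutions (linear, relu, add, ...)
    }
  }
  std::cout << "(dim, degree) pairs visited: " << pairs << '\n'; // 3 * 4 = 12
}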
@@ -173,7 +197,7 @@ Substitution create_partition_linear_combine(nonnegative_int num_dims,
           set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
                                OperatorAttributeValue{degree}),
           set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
-                               OperatorAttributeValue{ff_dim_t{1_n}}),
+                               OperatorAttributeValue{ff_dim_t{0_n}}),
       }};
   OutputGraphExprValue o_partition_input_output =
       get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));
@@ -265,13 +289,13 @@ Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
           set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
                                OperatorAttributeValue{degree}),
           set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
-                               OperatorAttributeValue{ff_dim_t{1_n}}),
+                               OperatorAttributeValue{ff_dim_t{0_n}}),
       }};

   OutputGraphExprValue o_partition_input_output =
       get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));

-  /*OutputOperatorAttrsAssignment replicate_weights_expr =
+  OutputOperatorAttrsAssignment replicate_weights_expr =
       OutputOperatorAttrsAssignment{
           std::nullopt,
           {
@@ -283,10 +307,7 @@ Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
       b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n));

   std::vector<OutputGraphExprValue> o_conv2d_inputs = {
-      o_partition_input_output, o_replicate_weights_output};*/
-
-  std::vector<OutputGraphExprValue> o_conv2d_inputs = {o_partition_input_output,
-                                                       o_weight};
+      o_partition_input_output, o_replicate_weights_output};

   OutputOperatorAttrsAssignment conv2d_expr = OutputOperatorAttrsAssignment{
       b.pattern_node_named("conv2d"),
@@ -321,15 +342,16 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,

   SubstitutionBuilder b;

-  auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all());
-  auto [p_query_weight, o_query_weight] =
+  auto [p_query_input, o_query_input] =
+      b.add_input(tensor_attribute_pattern_match_all());
+  auto [p_key_input, o_key_input] =
       b.add_input(tensor_attribute_pattern_match_all());
-  auto [p_key_weight, o_key_weight] =
+  auto [p_value_input, o_value_input] =
       b.add_input(tensor_attribute_pattern_match_all());
-  auto [p_value_weight, o_value_weight] =
+  auto [p_weights, o_weights] =
       b.add_input(tensor_attribute_pattern_match_all());
   std::vector<PatternValue> p_inputs = {
-      p_input, p_input, p_input, p_query_weight, p_key_weight, p_value_weight};
+      p_query_input, p_key_input, p_value_input, p_weights};

   OperatorAttributePattern attention_pattern = OperatorAttributePattern{{
       op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION),
@@ -351,19 +373,35 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,
           set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
                                OperatorAttributeValue{degree}),
           set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
-                               OperatorAttributeValue{ff_dim_t{1_n}}),
+                               OperatorAttributeValue{ff_dim_t{0_n}}),
       }};

-  OutputGraphExprValue o_partition_input_output =
-      get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));
+  OutputGraphExprValue o_partition_query_input_output = get_only(
+      b.add_output_graph_node(partition_input_expr, {o_query_input}, 1_n));
+
+  OutputGraphExprValue o_partition_key_input_output = get_only(
+      b.add_output_graph_node(partition_input_expr, {o_key_input}, 1_n));
+
+  OutputGraphExprValue o_partition_value_input_output = get_only(
+      b.add_output_graph_node(partition_input_expr, {o_value_input}, 1_n));
+
+  OutputOperatorAttrsAssignment replicate_weight_expr =
+      OutputOperatorAttrsAssignment{
+          std::nullopt,
+          {
+              set_op_type_attr(OperatorType::REPLICATE),
+              set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
+                                   OperatorAttributeValue{degree}),
+          }};
+
+  OutputGraphExprValue o_replicate_weight_output = get_only(
+      b.add_output_graph_node(replicate_weight_expr, {o_weights}, 1_n));

   std::vector<OutputGraphExprValue> o_attention_inputs = {
-      o_partition_input_output,
-      o_partition_input_output,
-      o_partition_input_output,
-      o_query_weight,
-      o_key_weight,
-      o_value_weight};
+      o_partition_query_input_output,
+      o_partition_key_input_output,
+      o_partition_value_input_output,
+      o_replicate_weight_output};

   OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{
       b.pattern_node_named("attention"),
@@ -394,17 +432,19 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,

 Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
                                                nonnegative_int degree) {
+
   SubstitutionBuilder b;

-  auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all());
-  auto [p_query_weight, o_query_weight] =
+  auto [p_query_input, o_query_input] =
       b.add_input(tensor_attribute_pattern_match_all());
-  auto [p_key_weight, o_key_weight] =
+  auto [p_key_input, o_key_input] =
       b.add_input(tensor_attribute_pattern_match_all());
-  auto [p_value_weight, o_value_weight] =
+  auto [p_value_input, o_value_input] =
+      b.add_input(tensor_attribute_pattern_match_all());
+  auto [p_weights, o_weights] =
       b.add_input(tensor_attribute_pattern_match_all());
   std::vector<PatternValue> p_inputs = {
-      p_input, p_input, p_input, p_query_weight, p_key_weight, p_value_weight};
+      p_query_input, p_key_input, p_value_input, p_weights};

   OperatorAttributePattern attention_pattern = OperatorAttributePattern{{
       op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION),
@@ -421,20 +461,40 @@ Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
   OutputOperatorAttrsAssignment replicate_input_expr =
       OutputOperatorAttrsAssignment{
           std::nullopt,
-          {set_op_type_attr(OperatorType::REPLICATE),
-           set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
-                                OperatorAttributeValue{degree})}};
+          {
+              set_op_type_attr(OperatorType::REPLICATE),
+              set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
+                                   OperatorAttributeValue{degree}),
+          }};

-  OutputGraphExprValue o_replicate_input_output =
-      get_only(b.add_output_graph_node(replicate_input_expr, {o_input}, 1_n));
+  OutputGraphExprValue o_replicate_query_input_output = get_only(
+      b.add_output_graph_node(replicate_input_expr, {o_query_input}, 1_n));
+
+  OutputGraphExprValue o_replicate_key_input_output = get_only(
+      b.add_output_graph_node(replicate_input_expr, {o_key_input}, 1_n));
+
+  OutputGraphExprValue o_replicate_value_input_output = get_only(
+      b.add_output_graph_node(replicate_input_expr, {o_value_input}, 1_n));
+
+  OutputOperatorAttrsAssignment partition_weight_expr =
+      OutputOperatorAttrsAssignment{
+          std::nullopt,
+          {
+              set_op_type_attr(OperatorType::REPARTITION),
+              set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
+                                   OperatorAttributeValue{degree}),
+              set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
+                                   OperatorAttributeValue{ff_dim_t{1_n}}),
+          }};
+
+  OutputGraphExprValue o_partition_weight_output = get_only(
+      b.add_output_graph_node(partition_weight_expr, {o_weights}, 1_n));

   std::vector<OutputGraphExprValue> o_attention_inputs = {
-      o_replicate_input_output,
-      o_replicate_input_output,
-      o_replicate_input_output,
-      o_query_weight,
-      o_key_weight,
-      o_value_weight};
+      o_replicate_query_input_output,
+      o_replicate_key_input_output,
+      o_replicate_value_input_output,
+      o_partition_weight_output};

   OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{
       b.pattern_node_named("attention"),
@@ -451,14 +511,83 @@ Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
                                OperatorAttributeValue{degree}),
           },
   };
-  OutputGraphExprValue o_reduce_output = get_only(
-      b.add_output_graph_node(reduce_expr, {o_attention_output}, 1_n));
+  OutputGraphExprValue o_reduce_output =
+      get_only(b.add_output_graph_node(reduce_expr, {o_attention_output}, 1_n));

   b.equate_outputs(p_attention_output, o_reduce_output);

   return b.get_substitution();
 }

+Substitution create_partition_softmax_combine(ff_dim_t softmax_dim,
+                                              ff_dim_t partition_dim,
+                                              nonnegative_int degree) {
+  if (partition_dim == softmax_dim) {
+    throw mk_runtime_error(
+        fmt::format("partition dim {} must not be equal to softmax dim {}",
+                    partition_dim,
+                    softmax_dim));
+  }
+  SubstitutionBuilder b;
+
+  auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all());
+  std::vector<PatternValue> p_inputs = {p_input};
+
+  OperatorAttributePattern softmax_pattern = OperatorAttributePattern{{
+      op_type_equals_constraint(OperatorType::SOFTMAX),
+      op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree),
+      op_attr_key_divisible_by(OperatorAttributeKey::SOFTMAX_DIM,
+                               softmax_dim.value),
+  }};
+
+  PatternValue p_softmax_output =
+      get_only(b.add_pattern_node(softmax_pattern,
+                                  p_inputs,
+                                  {tensor_attribute_pattern_match_all()},
+                                  "softmax"));
+
+  OutputOperatorAttrsAssignment partition_input_expr =
+      OutputOperatorAttrsAssignment{
+          std::nullopt,
+          {
+              set_op_type_attr(OperatorType::REPARTITION),
+              set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
+                                   OperatorAttributeValue{degree}),
+              set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
+                                   OperatorAttributeValue{partition_dim}),
+          }};
+
+  OutputGraphExprValue o_partition_input_output =
+      get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));
+
+  std::vector<OutputGraphExprValue> o_softmax_inputs = {
+      o_partition_input_output};
+
+  OutputOperatorAttrsAssignment softmax_expr = OutputOperatorAttrsAssignment{
+      b.pattern_node_named("softmax"),
+      {},
+  };
+  OutputGraphExprValue o_softmax_output =
+      get_only(b.add_output_graph_node(softmax_expr, o_softmax_inputs, 1_n));
+
+  OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{
+      std::nullopt,
+      {
+          set_op_type_attr(OperatorType::COMBINE),
+          set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
+                               OperatorAttributeValue{degree}),
+          set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
+                               OperatorAttributeValue{partition_dim}),
+      },
+  };
+  OutputGraphExprValue o_combine_output =
+      get_only(b.add_output_graph_node(combine_expr, {o_softmax_output}, 1_n));
+
+  b.equate_outputs(p_softmax_output, o_combine_output);
+
+  return b.get_substitution();
+}
+
 Substitution create_partition_add_combine(ff_dim_t parallel_dim,
                                           nonnegative_int degree) {
   SubstitutionBuilder b;
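The guard in create_partition_softmax_combine above (and the matching partition_dim != softmax_dim check in get_substitution_set) reflects the fact that softmax only commutes with partitioning along a different dimension. Below is a self-contained numeric sketch of what breaks when the shard boundary cuts the softmax dimension; plain C++, not FlexFlow code.

// Why partition_dim == softmax_dim is rejected: splitting along the softmax
// dimension changes the normalizer. Toy check on a single row.
#include <cmath>
#include <iostream>
#include <vector>

std::vector<double> softmax(std::vector<double> const &x) {
  double denom = 0.0;
  for (double v : x) denom += std::exp(v);
  std::vector<double> out;
  for (double v : x) out.push_back(std::exp(v) / denom);
  return out;
}

int main() {
  std::vector<double> row = {1.0, 2.0, 3.0};

  // Softmax over the full row: the denominator sees all three entries.
  std::vector<double> full = softmax(row);

  // Partition along the softmax dim and apply softmax per shard: each shard
  // normalizes over a partial denominator, so the combined result is wrong.
  std::vector<double> left = softmax({1.0, 2.0});
  std::vector<double> right = softmax({3.0});

  std::cout << "full[0]  = " << full[0] << '\n';  // ~0.090
  std::cout << "shard[0] = " << left[0] << '\n';  // ~0.269, mismatch
  std::cout << "full[2]  = " << full[2] << '\n';  // ~0.665
  std::cout << "shard[2] = " << right[0] << '\n'; // 1.0, mismatch
}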
