Skip to content

Commit c9c6432

Browse files
committed
porting unity substitutions from old branch
1 parent 5abb225 commit c9c6432

File tree

10 files changed

+2205
-98
lines changed

10 files changed

+2205
-98
lines changed

lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ OperatorAttributeConstraint op_type_equals_constraint(OperatorType);
99

1010
OperatorAttributeConstraint op_attr_key_equals(OperatorAttributeKey,
1111
OperatorAttributeValue const &);
12-
OperatorAttributeConstraint
13-
op_attr_key_divisible_by(OperatorAttributeKey, nonnegative_int denominator);
12+
OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey,
13+
positive_int denominator);
1414
OperatorAttributeConstraint
1515
make_equals_constraint(OperatorAttributeExpr const &,
1616
OperatorAttributeValue const &);

lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H
33

44
#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h"
5-
#include "utils/nonnegative_int/nonnegative_int.h"
5+
#include "utils/positive_int/positive_int.h"
66

77
namespace FlexFlow {
88

99
TensorAttributePattern tensor_attribute_pattern_match_all();
1010
TensorAttributePattern
11-
tensor_attr_pattern_require_num_dims(nonnegative_int num_dims);
11+
tensor_attr_pattern_require_num_dims(positive_int num_dims);
1212

1313
} // namespace FlexFlow
1414

lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,18 @@ includes = [
1313
"utils/hash/vector.h",
1414
"utils/fmt/vector.h",
1515
"utils/nonnegative_int/nonnegative_int.h",
16+
"utils/positive_int/positive_int.h",
1617
]
1718

1819
[[values]]
1920
type = "::FlexFlow::nonnegative_int"
2021

22+
[[values]]
23+
type = "::FlexFlow::positive_int"
24+
2125
[[values]]
2226
type = "std::vector<::FlexFlow::nonnegative_int>"
27+
28+
[[values]]
29+
type = "std::vector<::FlexFlow::positive_int>"
30+

lib/substitutions/include/substitutions/unity_substitution_set.h

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,25 @@ namespace FlexFlow {
1010
std::vector<Substitution>
1111
get_substitution_set(MachineSpecification const &resources);
1212

13-
Substitution create_combine_inception(nonnegative_int num_convs,
14-
nonnegative_int num_dims,
15-
nonnegative_int degree);
16-
Substitution create_combine_concat(nonnegative_int num_inputs,
17-
nonnegative_int num_dims,
18-
nonnegative_int degree);
19-
Substitution create_replicate_linear_combine(nonnegative_int num_dims,
20-
nonnegative_int degree,
13+
Substitution create_replicate_linear_combine(positive_int num_dims,
14+
positive_int degree,
2115
bool use_bias);
22-
Substitution create_partition_linear_combine(nonnegative_int num_dims,
23-
nonnegative_int degree,
24-
Activation activation,
16+
Substitution create_partition_linear_combine(positive_int num_dims,
17+
positive_int degree,
2518
bool use_bias);
26-
Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
27-
nonnegative_int degree);
28-
Substitution create_partition_attention_combine(nonnegative_int num_heads,
29-
nonnegative_int degree);
30-
Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
31-
nonnegative_int degree);
19+
Substitution create_partition_conv2d_combine(positive_int num_dims,
20+
positive_int degree);
21+
Substitution create_partition_attention_combine(positive_int num_heads,
22+
positive_int degree);
23+
Substitution create_replicate_attention_reduce(positive_int num_heads,
24+
positive_int degree);
3225
Substitution create_partition_add_combine(ff_dim_t parallel_dim,
33-
nonnegative_int degree);
26+
positive_int degree);
3427
Substitution create_partition_relu_combine(ff_dim_t parallel_dim,
35-
nonnegative_int degree);
36-
Substitution create_partition_concat_combine(nonnegative_int num_inputs,
37-
ff_dim_t concat_dim,
38-
ff_dim_t parallel_dim,
39-
nonnegative_int degree);
28+
positive_int degree);
4029
Substitution create_partition_softmax_combine(ff_dim_t softmax_dim,
4130
ff_dim_t partition_dim,
42-
nonnegative_int degree);
31+
positive_int degree);
4332
Substitution create_fuse_linear_activation(Activation activation);
4433

4534
} // namespace FlexFlow

lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ std::optional<OperatorAttributeValue> get_attribute(ConcatAttrs const &p,
8383
std::optional<OperatorAttributeValue> get_attribute(Conv2DAttrs const &p,
8484
OperatorAttributeKey key) {
8585
switch (key) {
86+
case OperatorAttributeKey::OUT_CHANNELS:
87+
return OperatorAttributeValue{p.out_channels};
8688
case OperatorAttributeKey::OP_TYPE:
8789
return OperatorAttributeValue{get_op_type(p)};
8890
case OperatorAttributeKey::KERNEL_H:
@@ -113,6 +115,12 @@ std::optional<OperatorAttributeValue> get_attribute(ElementBinaryAttrs const &p,
113115
switch (key) {
114116
case OperatorAttributeKey::OP_TYPE:
115117
return OperatorAttributeValue{get_op_type(p)};
118+
case OperatorAttributeKey::DATA_TYPE:
119+
return OperatorAttributeValue{p.compute_type};
120+
case OperatorAttributeKey::SHOULD_BROADCAST_LHS:
121+
return OperatorAttributeValue{p.should_broadcast_lhs};
122+
case OperatorAttributeKey::SHOULD_BROADCAST_RHS:
123+
return OperatorAttributeValue{p.should_broadcast_rhs};
116124
default:
117125
return std::nullopt;
118126
}
@@ -123,6 +131,8 @@ std::optional<OperatorAttributeValue> get_attribute(ElementUnaryAttrs const &p,
123131
switch (key) {
124132
case OperatorAttributeKey::OP_TYPE:
125133
return OperatorAttributeValue{get_op_type(p)};
134+
case OperatorAttributeKey::SCALAR:
135+
return OperatorAttributeValue{p.scalar};
126136
default:
127137
return std::nullopt;
128138
}
@@ -227,10 +237,20 @@ std::optional<OperatorAttributeValue>
227237
switch (key) {
228238
case OperatorAttributeKey::OP_TYPE:
229239
return OperatorAttributeValue{get_op_type(p)};
240+
case OperatorAttributeKey::EMBED_DIM:
241+
return OperatorAttributeValue{p.embed_dim};
242+
case OperatorAttributeKey::KDIM:
243+
return OperatorAttributeValue{p.kdim};
244+
case OperatorAttributeKey::VDIM:
245+
return OperatorAttributeValue{p.vdim};
230246
case OperatorAttributeKey::NUM_HEADS:
231247
return OperatorAttributeValue{p.num_heads};
232-
case OperatorAttributeKey::USE_BIAS:
248+
case OperatorAttributeKey::BIAS:
233249
return OperatorAttributeValue{p.bias};
250+
case OperatorAttributeKey::ADD_BIAS_KV:
251+
return OperatorAttributeValue{p.add_bias_kv};
252+
case OperatorAttributeKey::ADD_ZERO_ATTN:
253+
return OperatorAttributeValue{p.add_zero_attn};
234254
case OperatorAttributeKey::DROPOUT:
235255
return OperatorAttributeValue{p.dropout};
236256
default:

lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ OperatorAttributeConstraint
2020
};
2121
}
2222

23-
OperatorAttributeConstraint
24-
op_attr_key_divisible_by(OperatorAttributeKey key,
25-
nonnegative_int denominator) {
23+
OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey key,
24+
positive_int denominator) {
2625
return OperatorAttributeConstraint{
2726
ConstraintType::DIVISIBLE_BY,
2827
OperatorAttributeExpr{key},

lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
6161
case OperatorType::NOOP:
6262
case OperatorType::INPUT:
6363
case OperatorType::WEIGHT:
64-
case OperatorType::CONV2D:
6564
case OperatorType::DROPOUT:
6665
case OperatorType::LINEAR:
6766
return PCGOperatorAttrs{LinearAttrs{
@@ -75,19 +74,72 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
7574
acc.get<std::optional<RegularizerAttrs>>(
7675
OperatorAttributeKey::REGULARIZER),
7776
}};
77+
case OperatorType::CONV2D:
78+
return PCGOperatorAttrs{Conv2DAttrs{
79+
/*out_channels=*/acc.get<positive_int>(
80+
OperatorAttributeKey::OUT_CHANNELS),
81+
/*kernel_h=*/acc.get<positive_int>(OperatorAttributeKey::KERNEL_H),
82+
/*kernel_w=*/acc.get<positive_int>(OperatorAttributeKey::KERNEL_W),
83+
/*stride_h=*/acc.get<positive_int>(OperatorAttributeKey::STRIDE_H),
84+
/*stride_w=*/acc.get<positive_int>(OperatorAttributeKey::STRIDE_W),
85+
/*padding_h=*/
86+
acc.get<nonnegative_int>(OperatorAttributeKey::PADDING_H),
87+
/*padding_w=*/
88+
acc.get<nonnegative_int>(OperatorAttributeKey::PADDING_W),
89+
/*groups=*/acc.get<positive_int>(OperatorAttributeKey::GROUPS),
90+
/*activation=*/
91+
acc.get<std::optional<Activation>>(OperatorAttributeKey::ACTIVATION),
92+
/*use_bias=*/acc.get<bool>(OperatorAttributeKey::USE_BIAS),
93+
}};
94+
case OperatorType::RELU:
95+
return PCGOperatorAttrs{ElementUnaryAttrs{
96+
acc.get<OperatorType>(OperatorAttributeKey::OP_TYPE),
97+
acc.get<std::optional<float>>(OperatorAttributeKey::SCALAR),
98+
}};
99+
case OperatorType::SOFTMAX:
100+
return PCGOperatorAttrs{SoftmaxAttrs{
101+
acc.get<ff_dim_t>(OperatorAttributeKey::AXIS),
102+
}};
103+
case OperatorType::EW_ADD:
104+
return PCGOperatorAttrs{ElementBinaryAttrs{
105+
acc.get<OperatorType>(OperatorAttributeKey::OP_TYPE),
106+
acc.get<DataType>(OperatorAttributeKey::DATA_TYPE),
107+
acc.get<bool>(OperatorAttributeKey::SHOULD_BROADCAST_LHS),
108+
acc.get<bool>(OperatorAttributeKey::SHOULD_BROADCAST_RHS),
109+
}};
110+
case OperatorType::REPLICATE:
111+
return PCGOperatorAttrs{ReplicateAttrs{
112+
/*replicate_degree=*/acc.get<positive_int>(
113+
OperatorAttributeKey::PARALLEL_DEGREE),
114+
}};
115+
case OperatorType::REPARTITION:
116+
return PCGOperatorAttrs{RepartitionAttrs{
117+
/*repartition_dim=*/acc.get<ff_dim_t>(
118+
OperatorAttributeKey::PARALLEL_DIM),
119+
/*repartition_degree=*/
120+
acc.get<positive_int>(OperatorAttributeKey::PARALLEL_DEGREE),
121+
}};
122+
case OperatorType::COMBINE:
123+
return PCGOperatorAttrs{CombineAttrs{
124+
/*combine_dim=*/acc.get<ff_dim_t>(OperatorAttributeKey::PARALLEL_DIM),
125+
/*combine_degree=*/
126+
acc.get<positive_int>(OperatorAttributeKey::PARALLEL_DEGREE),
127+
}};
128+
case OperatorType::REDUCTION:
129+
return PCGOperatorAttrs{ReductionAttrs{
130+
acc.get<positive_int>(OperatorAttributeKey::PARALLEL_DEGREE),
131+
}};
78132
case OperatorType::BATCHMATMUL:
79133
case OperatorType::SCALAR_MULTIPLY:
80134
case OperatorType::SCALAR_ADD:
81135
case OperatorType::SCALAR_FLOOR_DIV:
82136
case OperatorType::SCALAR_TRUE_DIV:
83137
case OperatorType::SCALAR_SUB:
84-
case OperatorType::RELU:
85138
case OperatorType::IDENTITY:
86139
case OperatorType::SIGMOID:
87140
case OperatorType::TANH:
88141
case OperatorType::ELU:
89142
case OperatorType::FLAT:
90-
case OperatorType::SOFTMAX:
91143
case OperatorType::BATCHNORM:
92144
case OperatorType::CONCAT:
93145
case OperatorType::SPLIT:
@@ -96,7 +148,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
96148
case OperatorType::RESHAPE:
97149
case OperatorType::REVERSE:
98150
case OperatorType::TRANSPOSE:
99-
case OperatorType::EW_ADD:
100151
case OperatorType::EW_MUL:
101152
case OperatorType::MATMUL:
102153
case OperatorType::MUL:
@@ -143,10 +194,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
143194
case OperatorType::LAYERNORM:
144195
case OperatorType::GATHER:
145196
case OperatorType::BROADCAST:
146-
case OperatorType::REPARTITION:
147-
case OperatorType::COMBINE:
148-
case OperatorType::REPLICATE:
149-
case OperatorType::REDUCTION:
150197
case OperatorType::BATCH:
151198
case OperatorType::PIPELINE:
152199
case OperatorType::FUSED_PARALLEL:

lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ TensorAttributePattern tensor_attribute_pattern_match_all() {
88
}
99

1010
TensorAttributePattern
11-
tensor_attr_pattern_require_num_dims(nonnegative_int num_dims) {
11+
tensor_attr_pattern_require_num_dims(positive_int num_dims) {
1212
return TensorAttributePattern{{
1313
TensorAttributeConstraint{
1414
ConstraintType::EQUAL,

0 commit comments

Comments
 (0)