Skip to content

Commit 0dd487c

Browse files
author
Victor Li
committed
Adding more to unity substitution set
1 parent 600e074 commit 0dd487c

File tree

4 files changed

+543
-11
lines changed

4 files changed

+543
-11
lines changed

lib/op-attrs/src/op-attrs/get_output_shapes.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "op-attrs/ops/layer_norm.h"
1616
#include "op-attrs/ops/linear.h"
1717
#include "op-attrs/ops/pool_2d.h"
18+
#include "op-attrs/ops/repartition.h"
1819
#include "op-attrs/ops/replicate.h"
1920
#include "op-attrs/ops/weight.h"
2021
#include "utils/overload.h"
@@ -78,6 +79,9 @@ std::vector<ParallelTensorShape>
7879
[&](ReplicateAttrs const &attrs) -> std::vector<ParallelTensorShape> {
7980
return {get_output_shape(attrs, inputs.at(0))};
8081
},
82+
[&](RepartitionAttrs const &attrs) -> std::vector<ParallelTensorShape> {
83+
return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))};
84+
},
8185
[&](WeightAttrs const &attrs) -> std::vector<ParallelTensorShape> {
8286
return {get_output_parallel_tensor_shape(attrs)};
8387
},

lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,24 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
7676
acc.get<std::optional<RegularizerAttrs>>(
7777
OperatorAttributeKey::REGULARIZER),
7878
}};
79+
case OperatorType::REPLICATE:
80+
return PCGOperatorAttrs{ReplicateAttrs{
81+
/*replicate_degree=*/acc.get<nonnegative_int>(
82+
OperatorAttributeKey::PARALLEL_DEGREE),
83+
}};
84+
case OperatorType::REPARTITION:
85+
return PCGOperatorAttrs{RepartitionAttrs{
86+
/*repartition_dim=*/acc.get<ff_dim_t>(
87+
OperatorAttributeKey::PARALLEL_DIM),
88+
/*repartition_degree=*/
89+
acc.get<nonnegative_int>(OperatorAttributeKey::PARALLEL_DEGREE),
90+
}};
91+
case OperatorType::COMBINE:
92+
return PCGOperatorAttrs{CombineAttrs{
93+
/*combine_dim=*/acc.get<ff_dim_t>(OperatorAttributeKey::PARALLEL_DIM),
94+
/*combine_degree=*/
95+
acc.get<nonnegative_int>(OperatorAttributeKey::PARALLEL_DEGREE),
96+
}};
7997
case OperatorType::BATCHMATMUL:
8098
case OperatorType::SCALAR_MULTIPLY:
8199
case OperatorType::SCALAR_ADD:
@@ -144,9 +162,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
144162
case OperatorType::LAYERNORM:
145163
case OperatorType::GATHER:
146164
case OperatorType::BROADCAST:
147-
case OperatorType::REPARTITION:
148-
case OperatorType::COMBINE:
149-
case OperatorType::REPLICATE:
150165
case OperatorType::REDUCTION:
151166
case OperatorType::BATCH:
152167
case OperatorType::PIPELINE:

lib/substitutions/src/substitutions/unity_substitution_set.cc

Lines changed: 100 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,14 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims,
6363
op_type_equals_constraint(OperatorType::LINEAR),
6464
op_attr_key_equals(OperatorAttributeKey::BIAS,
6565
OperatorAttributeValue{use_bias}),
66-
op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS,
67-
nonnegative_int{degree}),
66+
op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree),
6867
}};
6968

70-
PatternValue p_linear_output = get_only(b.add_pattern_node(
71-
linear_pattern,
72-
p_inputs,
73-
{tensor_attr_pattern_require_num_dims(nonnegative_int{num_dims})},
74-
"linear"));
69+
PatternValue p_linear_output = get_only(
70+
b.add_pattern_node(linear_pattern,
71+
p_inputs,
72+
{tensor_attr_pattern_require_num_dims(num_dims)},
73+
"linear"));
7574

7675
OutputOperatorAttrsAssignment replicate_input_expr =
7776
OutputOperatorAttrsAssignment{
@@ -148,7 +147,100 @@ Substitution create_partition_linear_combine(nonnegative_int num_dims,
148147
nonnegative_int degree,
149148
Activation activation,
150149
bool use_bias) {
151-
NOT_IMPLEMENTED();
150+
SubstitutionBuilder b;
151+
152+
auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all());
153+
auto [p_weight, o_weight] = b.add_input(tensor_attribute_pattern_match_all());
154+
std::vector<PatternValue> p_inputs = {p_input, p_weight};
155+
156+
std::optional<OutputGraphExprValue> o_bias = std::nullopt;
157+
if (use_bias) {
158+
std::pair<PatternValue, OutputGraphExprValue> bias =
159+
b.add_input(tensor_attribute_pattern_match_all());
160+
p_inputs.push_back(bias.first);
161+
o_bias = bias.second;
162+
}
163+
164+
OperatorAttributePattern linear_pattern = OperatorAttributePattern{{
165+
op_type_equals_constraint(OperatorType::LINEAR),
166+
op_attr_key_equals(OperatorAttributeKey::BIAS,
167+
OperatorAttributeValue{use_bias}),
168+
op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree),
169+
}};
170+
171+
PatternValue p_linear_output = get_only(
172+
b.add_pattern_node(linear_pattern,
173+
p_inputs,
174+
{tensor_attr_pattern_require_num_dims(num_dims)},
175+
"linear"));
176+
177+
OutputOperatorAttrsAssignment partition_input_expr =
178+
OutputOperatorAttrsAssignment{
179+
std::nullopt,
180+
{
181+
set_op_type_attr(OperatorType::REPARTITION),
182+
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
183+
OperatorAttributeValue{degree}),
184+
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
185+
OperatorAttributeValue{ff_dim_t{1_n}}),
186+
}};
187+
OutputGraphExprValue o_partition_input_output =
188+
get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));
189+
190+
OutputOperatorAttrsAssignment replicate_weights_expr =
191+
OutputOperatorAttrsAssignment{
192+
std::nullopt,
193+
{
194+
set_op_type_attr(OperatorType::REPLICATE),
195+
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
196+
OperatorAttributeValue{degree}),
197+
}};
198+
OutputGraphExprValue o_replicate_weights_output = get_only(
199+
b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n));
200+
201+
std::vector<OutputGraphExprValue> o_linear_inputs = {
202+
o_partition_input_output, o_replicate_weights_output};
203+
204+
if (use_bias) {
205+
OutputOperatorAttrsAssignment replicate_bias_expr =
206+
OutputOperatorAttrsAssignment{
207+
std::nullopt,
208+
{
209+
set_op_type_attr(OperatorType::REPLICATE),
210+
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
211+
OperatorAttributeValue{degree}),
212+
}};
213+
OutputGraphExprValue o_replicate_bias_output = get_only(
214+
b.add_output_graph_node(replicate_bias_expr, {o_bias.value()}, 1_n));
215+
o_linear_inputs.push_back(o_replicate_bias_output);
216+
}
217+
218+
OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{
219+
b.pattern_node_named("linear"),
220+
{},
221+
};
222+
OutputGraphExprValue o_linear_output =
223+
get_only(b.add_output_graph_node(linear_expr, o_linear_inputs, 1_n));
224+
225+
OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{
226+
std::nullopt,
227+
{
228+
set_op_type_attr(OperatorType::COMBINE),
229+
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
230+
OperatorAttributeValue{degree}),
231+
set_attr_to_constant(
232+
OperatorAttributeKey::PARALLEL_DIM,
233+
OperatorAttributeValue{ff_dim_t{
234+
nonnegative_int{num_dims.unwrap_nonnegative() - 1},
235+
}}),
236+
},
237+
};
238+
OutputGraphExprValue o_combine_output =
239+
get_only(b.add_output_graph_node(combine_expr, {o_linear_output}, 1_n));
240+
241+
b.equate_outputs(p_linear_output, o_combine_output);
242+
243+
return b.get_substitution();
152244
}
153245

154246
Substitution create_partition_conv2d_combine(nonnegative_int num_dims,

0 commit comments

Comments
 (0)