@@ -63,15 +63,14 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims,
6363 op_type_equals_constraint (OperatorType::LINEAR),
6464 op_attr_key_equals (OperatorAttributeKey::BIAS,
6565 OperatorAttributeValue{use_bias}),
66- op_attr_key_divisible_by (OperatorAttributeKey::OUT_CHANNELS,
67- nonnegative_int{degree}),
66+ op_attr_key_divisible_by (OperatorAttributeKey::OUT_CHANNELS, degree),
6867 }};
6968
70- PatternValue p_linear_output = get_only (b. add_pattern_node (
71- linear_pattern,
72- p_inputs,
73- {tensor_attr_pattern_require_num_dims (nonnegative_int{ num_dims} )},
74- " linear" ));
69+ PatternValue p_linear_output = get_only (
70+ b. add_pattern_node ( linear_pattern,
71+ p_inputs,
72+ {tensor_attr_pattern_require_num_dims (num_dims)},
73+ " linear" ));
7574
7675 OutputOperatorAttrsAssignment replicate_input_expr =
7776 OutputOperatorAttrsAssignment{
@@ -148,7 +147,100 @@ Substitution create_partition_linear_combine(nonnegative_int num_dims,
148147 nonnegative_int degree,
149148 Activation activation,
150149 bool use_bias) {
151- NOT_IMPLEMENTED ();
150+ SubstitutionBuilder b;
151+
152+ auto [p_input, o_input] = b.add_input (tensor_attribute_pattern_match_all ());
153+ auto [p_weight, o_weight] = b.add_input (tensor_attribute_pattern_match_all ());
154+ std::vector<PatternValue> p_inputs = {p_input, p_weight};
155+
156+ std::optional<OutputGraphExprValue> o_bias = std::nullopt ;
157+ if (use_bias) {
158+ std::pair<PatternValue, OutputGraphExprValue> bias =
159+ b.add_input (tensor_attribute_pattern_match_all ());
160+ p_inputs.push_back (bias.first );
161+ o_bias = bias.second ;
162+ }
163+
164+ OperatorAttributePattern linear_pattern = OperatorAttributePattern{{
165+ op_type_equals_constraint (OperatorType::LINEAR),
166+ op_attr_key_equals (OperatorAttributeKey::BIAS,
167+ OperatorAttributeValue{use_bias}),
168+ op_attr_key_divisible_by (OperatorAttributeKey::OUT_CHANNELS, degree),
169+ }};
170+
171+ PatternValue p_linear_output = get_only (
172+ b.add_pattern_node (linear_pattern,
173+ p_inputs,
174+ {tensor_attr_pattern_require_num_dims (num_dims)},
175+ " linear" ));
176+
177+ OutputOperatorAttrsAssignment partition_input_expr =
178+ OutputOperatorAttrsAssignment{
179+ std::nullopt ,
180+ {
181+ set_op_type_attr (OperatorType::REPARTITION),
182+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
183+ OperatorAttributeValue{degree}),
184+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DIM,
185+ OperatorAttributeValue{ff_dim_t {1_n}}),
186+ }};
187+ OutputGraphExprValue o_partition_input_output =
188+ get_only (b.add_output_graph_node (partition_input_expr, {o_input}, 1_n));
189+
190+ OutputOperatorAttrsAssignment replicate_weights_expr =
191+ OutputOperatorAttrsAssignment{
192+ std::nullopt ,
193+ {
194+ set_op_type_attr (OperatorType::REPLICATE),
195+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
196+ OperatorAttributeValue{degree}),
197+ }};
198+ OutputGraphExprValue o_replicate_weights_output = get_only (
199+ b.add_output_graph_node (replicate_weights_expr, {o_weight}, 1_n));
200+
201+ std::vector<OutputGraphExprValue> o_linear_inputs = {
202+ o_partition_input_output, o_replicate_weights_output};
203+
204+ if (use_bias) {
205+ OutputOperatorAttrsAssignment replicate_bias_expr =
206+ OutputOperatorAttrsAssignment{
207+ std::nullopt ,
208+ {
209+ set_op_type_attr (OperatorType::REPLICATE),
210+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
211+ OperatorAttributeValue{degree}),
212+ }};
213+ OutputGraphExprValue o_replicate_bias_output = get_only (
214+ b.add_output_graph_node (replicate_bias_expr, {o_bias.value ()}, 1_n));
215+ o_linear_inputs.push_back (o_replicate_bias_output);
216+ }
217+
218+ OutputOperatorAttrsAssignment linear_expr = OutputOperatorAttrsAssignment{
219+ b.pattern_node_named (" linear" ),
220+ {},
221+ };
222+ OutputGraphExprValue o_linear_output =
223+ get_only (b.add_output_graph_node (linear_expr, o_linear_inputs, 1_n));
224+
225+ OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{
226+ std::nullopt ,
227+ {
228+ set_op_type_attr (OperatorType::COMBINE),
229+ set_attr_to_constant (OperatorAttributeKey::PARALLEL_DEGREE,
230+ OperatorAttributeValue{degree}),
231+ set_attr_to_constant (
232+ OperatorAttributeKey::PARALLEL_DIM,
233+ OperatorAttributeValue{ff_dim_t {
234+ nonnegative_int{num_dims.unwrap_nonnegative () - 1 },
235+ }}),
236+ },
237+ };
238+ OutputGraphExprValue o_combine_output =
239+ get_only (b.add_output_graph_node (combine_expr, {o_linear_output}, 1_n));
240+
241+ b.equate_outputs (p_linear_output, o_combine_output);
242+
243+ return b.get_substitution ();
152244}
153245
154246Substitution create_partition_conv2d_combine (nonnegative_int num_dims,
0 commit comments