Skip to content

Commit 12f8717

Browse files
author
Victor Li
committed
Fixing unity substituion tests
1 parent 1c18cdd commit 12f8717

File tree

3 files changed

+318
-17
lines changed

3 files changed

+318
-17
lines changed

lib/substitutions/include/substitutions/unity_substitution_set.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims,
2121
bool use_bias);
2222
Substitution create_partition_linear_combine(nonnegative_int num_dims,
2323
nonnegative_int degree,
24-
Activation activation,
2524
bool use_bias);
2625
Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
2726
nonnegative_int degree);

lib/substitutions/src/substitutions/unity_substitution_set.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ std::vector<Substitution>
2121
create_replicate_linear_combine(num_dims, degree, true));
2222
substitutions.push_back(
2323
create_replicate_linear_combine(num_dims, degree, false));
24+
substitutions.push_back(
25+
create_partition_linear_combine(num_dims, degree, true));
26+
substitutions.push_back(
27+
create_partition_linear_combine(num_dims, degree, false));
2428
}
2529
}
2630
substitutions.push_back(create_fuse_linear_activation(Activation::RELU));
@@ -63,7 +67,6 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims,
6367
op_type_equals_constraint(OperatorType::LINEAR),
6468
op_attr_key_equals(OperatorAttributeKey::BIAS,
6569
OperatorAttributeValue{use_bias}),
66-
6770
op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree),
6871
}};
6972

@@ -146,9 +149,7 @@ Substitution create_replicate_linear_combine(nonnegative_int num_dims,
146149

147150
Substitution create_partition_linear_combine(nonnegative_int num_dims,
148151
nonnegative_int degree,
149-
Activation activation,
150152
bool use_bias) {
151-
152153
SubstitutionBuilder b;
153154

154155
auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all());

0 commit comments

Comments
 (0)