Skip to content

Commit f5c49c7

Browse files
author
Victor Li
committed
Adding in some a simple function to get the currently available substritutions
1 parent 030b0e8 commit f5c49c7

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

lib/substitutions/include/substitutions/unity_substitution_set.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNITY_SUBSTITUTION_SET_H
22
#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNITY_SUBSTITUTION_SET_H
33

4+
#include "pcg/machine_specification.dtg.h"
45
#include "substitutions/substitution.dtg.h"
6+
#include "utils/fmt/vector.h"
57

68
namespace FlexFlow {
79

10+
std::vector<Substitution>
11+
get_substitution_set(MachineSpecification const &resources);
12+
813
Substitution create_combine_inception(int num_convs, int num_dims, int degree);
914
Substitution create_combine_concat(int num_inputs, int num_dims, int degree);
10-
Substitution create_replicate_linear_combine(int num_dims,
11-
int degree,
12-
Activation activation,
13-
bool use_bias);
15+
Substitution
16+
create_replicate_linear_combine(int num_dims, int degree, bool use_bias);
1417
Substitution create_partition_linear_combine(int num_dims,
1518
int degree,
1619
Activation activation,
@@ -27,6 +30,7 @@ Substitution create_partition_concat_combine(int num_inputs,
2730
Substitution create_partition_softmax_combine(ff_dim_t softmax_dim,
2831
ff_dim_t partition_dim,
2932
int degree);
33+
Substitution create_fuse_linear_activation(Activation activation);
3034

3135
} // namespace FlexFlow
3236

lib/substitutions/src/substitutions/unity_substitution_set.cc

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,24 @@
77

88
namespace FlexFlow {
99

10+
std::vector<Substitution>
11+
get_substitution_set(MachineSpecification const &resources) {
12+
std::vector<Substitution> substitutions;
13+
for (int num_dims = 1; num_dims <= MAX_TENSOR_DIM; num_dims++) {
14+
for (int degree = 1; degree <= resources.num_nodes; degree *= 2) {
15+
substitutions.push_back(
16+
create_replicate_linear_combine(num_dims, degree, true));
17+
substitutions.push_back(
18+
create_replicate_linear_combine(num_dims, degree, false));
19+
}
20+
}
21+
substitutions.push_back(create_fuse_linear_activation(Activation::RELU));
22+
substitutions.push_back(create_fuse_linear_activation(Activation::SIGMOID));
23+
substitutions.push_back(create_fuse_linear_activation(Activation::TANH));
24+
substitutions.push_back(create_fuse_linear_activation(Activation::GELU));
25+
return substitutions;
26+
}
27+
1028
Substitution create_combine_inception(int num_convs, int num_dims, int degree) {
1129
NOT_IMPLEMENTED();
1230
}
@@ -122,11 +140,6 @@ Substitution create_partition_linear_combine(int num_dims,
122140
NOT_IMPLEMENTED();
123141
}
124142

125-
Substitution create_partition_linear_combine(int degree,
126-
Activation activation) {
127-
NOT_IMPLEMENTED();
128-
}
129-
130143
Substitution create_partition_conv2d_combine(int num_dims, int degree) {
131144
NOT_IMPLEMENTED();
132145
}

0 commit comments

Comments
 (0)