Skip to content

Commit 79a3d19

Browse files
committed
readding mcmc
1 parent 419a8b1 commit 79a3d19

20 files changed

+656
-1
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H
2+
#define _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H
3+
4+
#include "pcg/machine_specification.dtg.h"
5+
#include "pcg/machine_view.dtg.h"
6+
#include "pcg/operator_task_space.dtg.h"
7+
8+
namespace FlexFlow {
9+
10+
bool is_valid_machine_view(MachineView const &mv,
11+
OperatorTaskSpace const &task,
12+
MachineSpecification const &ms);
13+
14+
std::unordered_set<MachineView>
15+
get_allowed_machine_views(MachineSpecification const &machine_spec,
16+
OperatorTaskSpace const &task,
17+
DeviceType device_type);
18+
19+
} // namespace FlexFlow
20+
21+
#endif
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_AND_UPDATE_MACHINE_MAPPING_H
2+
#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_AND_UPDATE_MACHINE_MAPPING_H
3+
4+
#include "compiler/search_result.dtg.h"
5+
#include "substitutions/pcg_pattern_match.dtg.h"
6+
#include "substitutions/sub_parallel_computation_graph.dtg.h"
7+
#include "substitutions/substitution.dtg.h"
8+
9+
namespace FlexFlow {
10+
/**
11+
* @brief Applies \p substitution to \p mapped_pcg at the location specified by
12+
* \p match, returning the resulting SearchResult (mapped pcg)
13+
*
14+
* @param mapped_pcg
15+
* @param substitution
16+
* @param match The location at which to apply substitution. This location in
17+
* sub_pcg should match substitution's PCGPattern. Likely created by running
18+
* FlexFlow::find_pattern_matches(PCGPattern const &,
19+
* SubParallelComputationGraph const &).
20+
* @return SearchResult A mapped pcg similar to mapped_pcg, but with
21+
* the subgraph of the pcg specified by match replaced with the result of the
22+
* output expression of substitution and the machine mapping updated to account
23+
* for the new output
24+
*/
25+
SearchResult apply_substitution_and_update_machine_mapping(
26+
SearchResult const &mapped_pcg,
27+
Substitution const &sub,
28+
PCGPatternMatch const &match);
29+
30+
} // namespace FlexFlow
31+
32+
#endif
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MCMC_MACHINE_MAPPING_MUTATION_SET_H
2+
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MCMC_MACHINE_MAPPING_MUTATION_SET_H
3+
4+
#include "compiler/machine_mapping/machine_mapping.h"
5+
#include "compiler/search_result.dtg.h"
6+
7+
namespace FlexFlow {
8+
std::optional<MachineMapping>
9+
get_naive_mapping(ParallelComputationGraph &pcg,
10+
MachineSpecification const &resources,
11+
DeviceType const &device_type);
12+
13+
std::optional<MachineMapping>
14+
get_random_mutation(SearchResult mapped_pcg,
15+
MachineSpecification const &resource,
16+
DeviceType const &device_type);
17+
} // namespace FlexFlow
18+
19+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef _FLEXFLOW_COMPILER_MCMC_OVER_MAPPED_PCG_H
2+
#define _FLEXFLOW_COMPILER_MCMC_OVER_MAPPED_PCG_H
3+
4+
#include "compiler/cost_estimator/runtime_only_cost_estimator.h"
5+
#include "compiler/mcmc/mcmc_over_mapped_pcg_config.dtg.h"
6+
#include "compiler/search_result.dtg.h"
7+
#include "pcg/computation_graph.h"
8+
#include "pcg/machine_specification.dtg.h"
9+
#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h"
10+
#include "substitutions/sub_parallel_computation_graph.h"
11+
#include "substitutions/substitution.h"
12+
13+
namespace FlexFlow {
14+
15+
SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg,
16+
RuntimeOnlyCostEstimator const &cost_estimator,
17+
MachineSpecification const &resources,
18+
MCMCOverMappedPCGConfig const &search_config);
19+
20+
} // namespace FlexFlow
21+
22+
#endif
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
namespace = "FlexFlow"
2+
name = "MCMCOverMappedPCGConfig"
3+
features = [
4+
"eq",
5+
"hash",
6+
"fmt",
7+
]
8+
9+
includes = [
10+
"pcg/device_type.dtg.h",
11+
"utils/nonnegative_int/nonnegative_int.h"
12+
]
13+
14+
[[fields]]
15+
name = "temperature"
16+
type = "float"
17+
18+
[[fields]]
19+
name = "num_iterations"
20+
type = "::FlexFlow::nonnegative_int"
21+
22+
[[fields]]
23+
name = "substitution_interval"
24+
type = "::FlexFlow::nonnegative_int"
25+
26+
[[fields]]
27+
name = "device_type"
28+
type = "::FlexFlow::DeviceType"
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H
2+
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H
3+
4+
#include "compiler/search_result.dtg.h"
5+
6+
namespace FlexFlow {
7+
8+
std::string format_as(SearchResult const &);
9+
std::ostream &operator<<(std::ostream &, SearchResult const &);
10+
11+
} // namespace FlexFlow
12+
13+
#endif
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
namespace = "FlexFlow"
2+
name = "SearchResult"
3+
features = [
4+
]
5+
6+
includes = [
7+
"pcg/parallel_computation_graph/parallel_computation_graph.h",
8+
"compiler/machine_mapping/machine_mapping.h",
9+
]
10+
11+
[[fields]]
12+
name = "pcg"
13+
type = "::FlexFlow::ParallelComputationGraph"
14+
15+
[[fields]]
16+
name = "machine_mapping"
17+
type = "::FlexFlow::MachineMapping"

lib/compiler/src/compiler/allowed_machine_views.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ static std::unordered_set<MachineView>
5757
product(transform(tensor_dims, [](positive_int num_devices) {
5858
return nonnegative_int{num_devices.int_from_positive_int() - 1};
5959
}));
60+
min_num_devices_with_full_stride_volume =
61+
std::max(min_num_devices_with_full_stride_volume, 1_n);
6062
return ceildiv(total_devices,
6163
positive_int{min_num_devices_with_full_stride_volume});
6264
};
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
#include "compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h"
2+
#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h"
3+
#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h"
4+
#include "substitutions/apply_substitution/apply_substitution.h"
5+
#include "substitutions/apply_substitution/evaluate_substitution_output.h"
6+
#include "substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.h"
7+
#include "substitutions/open_parallel_tensor_guid_t.h"
8+
#include "substitutions/pcg_pattern_match.h"
9+
#include "substitutions/sub_parallel_computation_graph.h"
10+
#include "substitutions/sub_parallel_computation_graph_data.dtg.h"
11+
#include "substitutions/sub_parallel_computation_graph_edge.h"
12+
#include "utils/containers/is_subseteq_of.h"
13+
#include "utils/containers/keys.h"
14+
#include "utils/containers/merge_maps.h"
15+
#include "utils/containers/restrict_keys.h"
16+
#include "utils/containers/set_minus.h"
17+
#include "utils/containers/values.h"
18+
19+
namespace FlexFlow {
20+
21+
SearchResult apply_substitution_and_update_machine_mapping(
22+
SearchResult const &mapped_pcg,
23+
Substitution const &sub,
24+
PCGPatternMatch const &match) {
25+
SubParallelComputationGraph spcg = sub_pcg_from_full_pcg(mapped_pcg.pcg);
26+
27+
auto substitution_output_result =
28+
evaluate_substitution_output(spcg, sub, match);
29+
SubParallelComputationGraph substitution_output_graph =
30+
substitution_output_result.first;
31+
OutputExprToResultSubPCGMapping output_expr_to_result_sub_pcg_mapping =
32+
substitution_output_result.second;
33+
34+
SubParallelComputationGraphData output_graph_data =
35+
get_sub_pcg_data(substitution_output_graph);
36+
SubParallelComputationGraphData pre_data = get_sub_pcg_data(spcg);
37+
38+
std::unordered_set<parallel_layer_guid_t> pre_nodes =
39+
keys(pre_data.node_data);
40+
std::unordered_set<parallel_layer_guid_t> matched_nodes =
41+
unordered_set_of(values(match.node_assignment));
42+
std::unordered_set<parallel_layer_guid_t> post_nodes_from_original_graph =
43+
set_minus(pre_nodes, matched_nodes);
44+
45+
std::unordered_map<parallel_layer_guid_t, MachineView> machine_views =
46+
mapped_pcg.machine_mapping.machine_views;
47+
48+
std::unordered_set<MachineView> substituted_machine_views =
49+
transform(matched_nodes, [&](parallel_layer_guid_t const &node) {
50+
return machine_views.at(node);
51+
});
52+
MachineView first_substituted_machine_view =
53+
*substituted_machine_views.begin();
54+
55+
std::unordered_map<parallel_layer_guid_t, ParallelLayerAttrs> post_node_data =
56+
[&] {
57+
std::unordered_map<parallel_layer_guid_t, ParallelLayerAttrs>
58+
post_node_data_from_orig = restrict_keys(
59+
pre_data.node_data, post_nodes_from_original_graph);
60+
std::unordered_map<parallel_layer_guid_t, ParallelLayerAttrs>
61+
post_node_data_from_sub = output_graph_data.node_data;
62+
63+
for (auto [layer, attrs] : post_node_data_from_sub) {
64+
machine_views.insert_or_assign(layer, first_substituted_machine_view);
65+
}
66+
67+
return merge_disjoint_maps(post_node_data_from_orig,
68+
post_node_data_from_sub);
69+
}();
70+
71+
std::unordered_set<SubParallelComputationGraphEdge> post_edges = [&] {
72+
std::unordered_set<SubParallelComputationGraphEdge> post_edges_from_orig =
73+
filter(pre_data.edges, [&](SubParallelComputationGraphEdge const &e) {
74+
if (e.raw_edge.has<DataflowInputEdge>()) {
75+
return true;
76+
} else {
77+
DataflowEdge dfe = e.raw_edge.get<DataflowEdge>();
78+
parallel_layer_guid_t src = parallel_layer_guid_t{dfe.src.node};
79+
parallel_layer_guid_t dst = parallel_layer_guid_t{dfe.dst.node};
80+
return !(contains(matched_nodes, src) ||
81+
contains(matched_nodes, dst));
82+
}
83+
});
84+
85+
std::unordered_set<SubParallelComputationGraphEdge> post_edges_from_sub =
86+
filter(output_graph_data.edges,
87+
[&](SubParallelComputationGraphEdge const &e) {
88+
return !e.raw_edge.has<DataflowInputEdge>();
89+
});
90+
91+
bidict<PatternNodeOutput, parallel_tensor_guid_t>
92+
output_orig_pattern_mapping = get_output_mapping_for_pcg_pattern_match(
93+
match, sub.pcg_pattern, spcg);
94+
bidict<parallel_tensor_guid_t, OutputGraphExprNodeOutput>
95+
output_post_outexpr_mapping = get_output_graph_expr_output_mapping(
96+
output_expr_to_result_sub_pcg_mapping,
97+
sub.output_graph_expr,
98+
substitution_output_graph);
99+
100+
std::unordered_set<SubParallelComputationGraphEdge> incoming_to_sub_edges;
101+
for (auto const &[pattern_input, base_graph_tensor] :
102+
match.input_assignment) {
103+
OutputGraphExprInput output_expr_input =
104+
sub.inputs_mapping.at_l(pattern_input);
105+
input_parallel_tensor_guid_t output_graph_input =
106+
output_expr_to_result_sub_pcg_mapping.input_mapping.at_r(
107+
output_expr_input);
108+
std::unordered_set<parallel_tensor_use_t> uses = get_parallel_tensor_uses(
109+
substitution_output_graph,
110+
open_parallel_tensor_guid_from_input(output_graph_input));
111+
for (parallel_tensor_use_t const &use : uses) {
112+
SubParallelComputationGraphEdge new_edge =
113+
subpcg_edge_from_tensor_and_use(base_graph_tensor, use);
114+
incoming_to_sub_edges.insert(new_edge);
115+
}
116+
}
117+
118+
std::unordered_set<SubParallelComputationGraphEdge> outgoing_from_sub_edges;
119+
for (ParallelComputationGraphEdge const &outgoing_edge :
120+
get_subgraph_outgoing_edges(spcg, matched_nodes)) {
121+
parallel_tensor_guid_t original_tensor =
122+
get_parallel_tensor(outgoing_edge);
123+
PatternNodeOutput pattern_tensor =
124+
output_orig_pattern_mapping.at_r(original_tensor);
125+
OutputGraphExprNodeOutput output_graph_tensor =
126+
sub.outputs_mapping.at_l(pattern_tensor);
127+
parallel_tensor_guid_t new_tensor =
128+
output_post_outexpr_mapping.at_r(output_graph_tensor);
129+
130+
SubParallelComputationGraphEdge new_edge =
131+
subpcg_edge_from_tensor_and_dst(
132+
new_tensor,
133+
get_dst_layer(outgoing_edge),
134+
get_dst_layer_input_idx(outgoing_edge));
135+
outgoing_from_sub_edges.insert(new_edge);
136+
}
137+
138+
return set_union(std::vector{
139+
post_edges_from_orig,
140+
post_edges_from_sub,
141+
incoming_to_sub_edges,
142+
outgoing_from_sub_edges,
143+
});
144+
}();
145+
146+
std::unordered_set<input_parallel_tensor_guid_t> post_inputs =
147+
pre_data.inputs;
148+
149+
std::unordered_map<open_parallel_tensor_guid_t, ParallelTensorAttrs>
150+
post_value_data = [&] {
151+
std::unordered_map<open_parallel_tensor_guid_t, ParallelTensorAttrs>
152+
post_value_data_from_orig = filter_keys(
153+
pre_data.value_data, [&](open_parallel_tensor_guid_t const &t) {
154+
return visit_open_parallel_tensor_guid(
155+
t,
156+
overload{
157+
[&](parallel_tensor_guid_t const &t) {
158+
return contains(post_nodes_from_original_graph,
159+
get_source_layer(t));
160+
},
161+
[](input_parallel_tensor_guid_t const &) {
162+
return true;
163+
},
164+
});
165+
});
166+
167+
std::unordered_map<open_parallel_tensor_guid_t, ParallelTensorAttrs>
168+
post_value_data_from_sub = output_graph_data.value_data;
169+
return merge_disjoint_maps(post_value_data_from_orig,
170+
post_value_data_from_sub);
171+
}();
172+
173+
SubParallelComputationGraphData post_data = SubParallelComputationGraphData{
174+
post_node_data,
175+
post_edges,
176+
post_inputs,
177+
post_value_data,
178+
};
179+
180+
assert(is_subseteq_of(keys(post_node_data), keys(machine_views)));
181+
182+
for (auto it = machine_views.begin(); it != machine_views.end();) {
183+
if (post_node_data.find(it->first) == post_node_data.end()) {
184+
it = machine_views.erase(it);
185+
} else {
186+
++it;
187+
}
188+
}
189+
190+
assert(keys(post_node_data) == keys(machine_views));
191+
192+
return SearchResult{
193+
pcg_from_sub_pcg_by_dropping_inputs(sub_pcg_from_graph_data(post_data)),
194+
MachineMapping{machine_views}};
195+
}
196+
197+
} // namespace FlexFlow

0 commit comments

Comments
 (0)