Skip to content

Commit fe6d950

Browse files
author
Victor Li
committed
removing substitution part of MCMC for now
1 parent d70a44a commit fe6d950

File tree

7 files changed

+39
-56
lines changed

7 files changed

+39
-56
lines changed

.envrc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
source_up_if_exists
2+
3+
use flake

.vimrc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
" example search path configuration
2+
set path=lib/runtime/**,lib/**
3+
4+
" set build target
5+
" let g:target = "pcg"
6+
7+
" set test target
8+
" let g:test_target = "utils-test"

lib/compiler/src/compiler/mcmc/mcmc_algorithm.cc

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "utils/full_binary_tree/binary_tree_path.h"
2929
#include "utils/graph/node/algorithms.h"
3030
#include "utils/optional.h"
31+
#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h"
32+
3133

3234
namespace FlexFlow {
3335

@@ -52,7 +54,6 @@ SearchResult apply_substitution_and_update_machine_mapping(
5254
SearchResult const &mapped_pcg,
5355
Substitution const &sub,
5456
PCGPatternMatch const &match) {
55-
// std::cout << "applying substitution" << std::endl;
5657
SubParallelComputationGraph spcg = sub_pcg_from_full_pcg(mapped_pcg.pcg);
5758

5859
auto substitution_output_result =
@@ -217,19 +218,17 @@ std::vector<SearchResult> all_pcgs_obtained_by_applying_a_substitution(
217218
SearchResult const &mapped_pcg,
218219
std::vector<Substitution> const &substitutions) {
219220
std::vector<SearchResult> results;
220-
SubParallelComputationGraph subpcg = sub_pcg_from_full_pcg(mapped_pcg.pcg);
221-
// std::cout << "len" << substitutions.size() << std::endl;
221+
//currently not functional
222+
/*SubParallelComputationGraph subpcg = sub_pcg_from_full_pcg(mapped_pcg.pcg);
222223
for (Substitution const &substitution : substitutions) {
223-
std::cout << "in outer loop" << std::endl;
224224
for (PCGPatternMatch const &pattern_match :
225225
find_pattern_matches(substitution.pcg_pattern, subpcg)) {
226-
std::cout << "getting stuff" << std::endl;
227226
SearchResult mapped_pcg_from_substitution =
228227
apply_substitution_and_update_machine_mapping(
229228
mapped_pcg, substitution, pattern_match);
230229
results.push_back(mapped_pcg_from_substitution);
231230
}
232-
}
231+
}*/
233232
return results;
234233
}
235234

@@ -267,16 +266,9 @@ SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg,
267266

268267
if (current_estimate < best_estimate) {
269268
best_state = current_state;
270-
std::cout << "new best state" << std::endl;
271-
std::cout << current_estimate << std::endl;
272-
std::cout << best_estimate << std::endl;
273269
} else if (current_estimate > best_estimate * search_config.alpha) {
274270
continue;
275-
} else {
276-
std::cout << current_estimate << best_estimate * search_config.alpha
277-
<< std::endl;
278271
}
279-
// std::cout << "Hello" << std::endl;
280272

281273
for (SearchResult const &new_mapped_pcg :
282274
all_pcgs_obtained_by_applying_a_substitution(current_mapped_pcg,
@@ -287,9 +279,6 @@ SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg,
287279
new_mapped_pcg.machine_mapping,
288280
resources);
289281

290-
std::cout << "new substitution" << std::endl;
291-
292-
std::cout << "new estimate" << new_estimate << std::endl;
293282
if (new_estimate <= search_config.threshold &&
294283
get_nodes(new_mapped_pcg.pcg.raw_graph).size() <=
295284
search_config.max_num_ops) {
@@ -304,11 +293,7 @@ SearchResult mcmc_graph_optimize(ParallelComputationGraph &pcg,
304293
cost_estimator,
305294
new_machine_mapping,
306295
resources);
307-
//std::cout << "new mapping" << std::endl;
308-
309-
//std::cout << "new estimate" << new_estimate << std::endl;
310296
if (new_estimate <= search_config.threshold) {
311-
//std::cout << "pushed" << std::endl;
312297
candidates.push(
313298
MCMCOptimizeState{SearchResult{current_mapped_pcg.pcg, new_machine_mapping}, -1 * new_estimate});
314299
}

lib/compiler/test/src/compiler/mcmc/mcmc_algorithm.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,10 @@ TEST_SUITE(FF_TEST_SUITE) {
7373
SearchResult result = mcmc_graph_optimize(
7474
pcg, cost_estimator, full_machine_spec, search_config);
7575

76-
std::cout << task_simulator_estimate_forward_pass_time(result.pcg,
77-
cost_estimator,
78-
result.machine_mapping,
79-
full_machine_spec) << std::endl;
80-
8176
CHECK(task_simulator_estimate_forward_pass_time(result.pcg,
8277
cost_estimator,
8378
result.machine_mapping,
8479
full_machine_spec) < 16);
8580

86-
CHECK(false);
8781
}
8882
}

lib/substitutions/src/substitutions/pcg_pattern.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ std::unordered_set<PatternNode> get_nodes(PCGPattern const &p) {
2323
static MatchAdditionalCriterion
2424
pcg_pattern_criteria(PCGPattern const &pattern,
2525
SubParallelComputationGraph const &pcg) {
26-
// std::cout << "GGETTING pattern criteria" << std::endl;
27-
// std::cout << get_nodes(pattern) << std::endl;
2826
return MatchAdditionalCriterion{
2927
[&](PatternNode const &patternNode, Node const &pcgNode) {
3028
return operator_satisfies_pattern(
@@ -42,8 +40,6 @@ static MatchAdditionalCriterion
4240
std::vector<PCGPatternMatch>
4341
find_pattern_matches(PCGPattern const &pattern,
4442
SubParallelComputationGraph const &pcg) {
45-
46-
// std::cout << "IN PATTERN MATCH"<< std::endl;
4743
std::vector<UnlabelledDataflowGraphPatternMatch> unlabelled_matches =
4844
find_pattern_matches(get_unlabelled_pattern(pattern),
4945
pcg.raw_graph,
@@ -69,20 +65,11 @@ UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) {
6965

7066
TensorAttributePattern get_tensor_pattern(PCGPattern const &p,
7167
PatternValue const &v) {
72-
73-
// std::cout << "get tensor pattern"<< std::endl;
74-
// std::cout << v << std::endl;
75-
// std::cout << raw_open_dataflow_value_from_pattern_value(v) << std::endl;
76-
TensorAttributePattern t =
77-
p.raw_graph.at(raw_open_dataflow_value_from_pattern_value(v));
78-
// std::cout << "hmm" << std::endl;
7968
return p.raw_graph.at(raw_open_dataflow_value_from_pattern_value(v));
8069
}
8170

8271
OperatorAttributePattern get_operator_pattern(PCGPattern const &p,
8372
PatternNode const &n) {
84-
85-
// std::cout << "get op pattern"<< std::endl;
8673
return p.raw_graph.at(n.raw_node);
8774
}
8875

lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,22 +71,16 @@ std::vector<UnlabelledDataflowGraphPatternMatch>
7171
find_pattern_matches(UnlabelledGraphPattern const &pattern,
7272
OpenDataflowGraphView const &graph,
7373
MatchAdditionalCriterion const &additional_criterion) {
74-
// std::cout << "find pattern matches" << std::endl;
7574
std::vector<UnlabelledDataflowGraphPatternMatch> matches;
7675
if (is_singleton_pattern(pattern)) {
77-
// std::cout << "singleton pattern" << std::endl;
7876
for (Node const &graph_node : get_nodes(graph)) {
79-
// std::cout << "11111" << std::endl;
8077
std::optional<UnlabelledDataflowGraphPatternMatch> candidate =
8178
get_candidate_singleton_match(pattern, graph, graph_node);
82-
// std::cout << "22222" << std::endl;
8379
if (candidate.has_value() &&
8480
unlabelled_pattern_does_match(
8581
pattern, graph, candidate.value(), additional_criterion)) {
86-
// std::cout << "2.555" << std::endl;
8782
matches.push_back(candidate.value());
8883
}
89-
// std::cout << "33333" << std::endl;
9084
}
9185
} else {
9286
PatternSplit split = find_even_split(pattern);
@@ -116,7 +110,6 @@ std::vector<UnlabelledDataflowGraphPatternMatch>
116110
}
117111
}
118112
}
119-
// std::cout << "return from pattern matches" << std::endl;
120113
return matches;
121114
}
122115

lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,30 +97,26 @@ bool pattern_matches_subgraph_under(
9797
&full_graph_values_to_subgraph_inputs,
9898
UnlabelledDataflowGraphPatternMatch const &match,
9999
MatchAdditionalCriterion const &additional_criterion) {
100-
// std::cout << "pattern amtches subgrpah under" << std::endl;
100+
std::cout << "gamer" << std::endl;
101+
std::cout << get_open_dataflow_values(pattern.raw_graph) << std::endl;
101102
SubgraphConcreteFromPattern concrete_from_pattern{
102103
match, full_graph_values_to_subgraph_inputs};
103104

104105
std::unordered_set<Node> concrete_nodes = get_nodes(subgraph);
105106
std::unordered_set<Node> concrete_nodes_from_match =
106107
transform(get_nodes(pattern), concrete_from_pattern);
107-
// std::cout << "mid of pattern amtches subgrpah under" << std::endl;
108108

109109
if (concrete_nodes != concrete_nodes_from_match) {
110110
return false;
111111
}
112112

113113
for (PatternNode const &pattern_node : get_nodes(pattern)) {
114-
// std::cout << "hello hello hello" << std::endl;
115114
if (!additional_criterion.node_criterion(
116115
pattern_node, concrete_from_pattern(pattern_node))) {
117-
// std::cout << "hello hello hello hello hello" << std::endl;
118116
return false;
119117
}
120118
}
121119

122-
// std::cout << "later mid of pattern amtches subgrpah under" << std::endl;
123-
124120
std::unordered_set<OpenDataflowEdge> concrete_edges = get_edges(subgraph);
125121
std::unordered_set<OpenDataflowEdge> concrete_edge_from_match =
126122
transform(get_edges(pattern), concrete_from_pattern);
@@ -137,14 +133,20 @@ bool pattern_matches_subgraph_under(
137133
if (concrete_values != concrete_values_from_match) {
138134
return false;
139135
}
136+
std::cout << "later later mid of pattern amtches subgrpah under" << std::endl;
137+
140138

141139
for (PatternValue const &pattern_value : get_values(pattern)) {
140+
std::cout << "dfjsahdfkiasjhdfkasjhdfkasdjhdfbgk awerhurgvt " << std::endl;
141+
std::cout << get_open_dataflow_values(pattern.raw_graph) << std::endl;
142+
std::cout << pattern_value << std::endl;
142143
if (!additional_criterion.value_criterion(
143144
pattern_value, concrete_from_pattern(pattern_value))) {
145+
std::cout << "dfjsahdfkiasjhdfkasjhdfkasdjhdfbgk awerhurgvtfwewefewfewf " << std::endl;
144146
return false;
145147
}
146148
}
147-
// std::cout << "end of pattern amtches subgrpah under" << std::endl;
149+
std::cout << "end of pattern amtches subgrpah under" << std::endl;
148150

149151
return true;
150152
}
@@ -154,33 +156,44 @@ bool unlabelled_pattern_does_match(
154156
OpenDataflowGraphView const &graph,
155157
UnlabelledDataflowGraphPatternMatch const &match,
156158
MatchAdditionalCriterion const &additional_criterion) {
157-
// std::cout << "unlabelled_pattern_does_match" << std::endl;
159+
std::cout << "unlabelled_pattern_does_match" << std::endl;
158160

159161
OpenDataflowSubgraphResult subgraph_result = subgraph_matched(graph, match);
160162
OpenDataflowGraphView matched_subgraph = subgraph_result.graph;
161163

162164
assert(left_entries(match.node_assignment) == get_nodes(pattern));
163165
assert(right_entries(match.node_assignment) == get_nodes(matched_subgraph));
164-
// std::cout << "middle of" << std::endl;
166+
std::cout << "middle of" << std::endl;
167+
std::cout << get_open_dataflow_values(pattern.raw_graph) << std::endl;
168+
std::cout << left_entries(match.node_assignment) << std::endl;
169+
std::cout << right_entries(match.node_assignment) << std::endl;
170+
std::cout << get_nodes(pattern) << std::endl;
171+
std::cout << get_nodes(matched_subgraph) << std::endl;
165172

166173
MatchAdditionalCriterion through_subgraph_operation =
167174
MatchAdditionalCriterion{
168175
additional_criterion.node_criterion,
169176
[&](PatternValue const &pv, OpenDataflowValue const &v) {
170177
return v.visit<bool>(overload{
171178
[&](DataflowOutput const &) {
179+
//std::cout << "whefihweoifhewfi" <<std::endl;
172180
return additional_criterion.value_criterion(pv, v);
173181
},
174182
[&](DataflowGraphInput const &subgraph_input) {
183+
//std::cout << "bobobobobob" << std::endl;
175184
OpenDataflowValue full_graph_value =
176185
subgraph_result.full_graph_values_to_subgraph_inputs.at_r(
177186
subgraph_input);
187+
/*std::cout << "ppopopopopopo" << std::endl;
188+
bool ss = additional_criterion.value_criterion(pv,
189+
full_graph_value);
190+
std::cout << "lolololololo" << std::endl;*/
178191
return additional_criterion.value_criterion(pv,
179192
full_graph_value);
180193
}});
181194
},
182195
};
183-
// std::cout << "end of unlabelled_pattern_does_match" << std::endl;
196+
//std::cout << "end of unlabelled_pattern_does_match" << std::endl;
184197

185198
return pattern_matches_subgraph_under(
186199
pattern,

0 commit comments

Comments
 (0)