Skip to content

Commit 392d8bc

Browse files
author
Victor Li
committed
Change OpCostMetrics.memory to be a nonnegative_int
1 parent 030bfd6 commit 392d8bc

File tree

7 files changed

+466
-24
lines changed

7 files changed

+466
-24
lines changed

lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ features = [
77
]
88

99
includes = [
10+
"utils/nonnegative_int/nonnegative_int.h"
1011
]
1112

1213
[[fields]]
@@ -15,4 +16,4 @@ type = "float"
1516

1617
[[fields]]
1718
name = "memory"
18-
type = "size_t"
19+
type = "::FlexFlow::nonnegative_int"

lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,13 @@ TEST_SUITE(FF_TEST_SUITE) {
146146

147147
auto map1 = std::unordered_map<OpCostEstimateKey, OpCostMetrics>{{
148148
{map_unmapped_op_cost_estimate_key(k1, mv1),
149-
OpCostMetrics{/*runtime=*/1.0, /*memory=*/0}},
149+
OpCostMetrics{/*runtime=*/1.0, /*memory=*/nonnegative_int{0}}},
150150
{map_unmapped_op_cost_estimate_key(k2, mv1),
151-
OpCostMetrics{/*runtime=*/2.0, /*memory=*/0}},
151+
OpCostMetrics{/*runtime=*/2.0, /*memory=*/nonnegative_int{0}}},
152152
{map_unmapped_op_cost_estimate_key(k1, mv2),
153-
OpCostMetrics{/*runtime=*/1.5, /*memory=*/0}},
153+
OpCostMetrics{/*runtime=*/1.5, /*memory=*/nonnegative_int{0}}},
154154
{map_unmapped_op_cost_estimate_key(k2, mv2),
155-
OpCostMetrics{/*runtime=*/2.5, /*memory=*/0}},
155+
OpCostMetrics{/*runtime=*/2.5, /*memory=*/nonnegative_int{0}}},
156156
}};
157157

158158
CostEstimator cost_estimator = make_fake_cost_estimator(

lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,10 @@ TEST_SUITE(FF_TEST_SUITE) {
146146

147147
CostEstimator cost_estimator = make_fake_cost_estimator(
148148
std::unordered_map<OpCostEstimateKey, OpCostMetrics>{{
149-
{map_unmapped_op_cost_estimate_key(k1, mv1), OpCostMetrics{1.0, 2}},
150-
{map_unmapped_op_cost_estimate_key(k2, mv1), OpCostMetrics{2.0, 3}},
151-
{map_unmapped_op_cost_estimate_key(k1, mv2), OpCostMetrics{1.5, 1}},
152-
{map_unmapped_op_cost_estimate_key(k2, mv2), OpCostMetrics{2.5, 2}},
149+
{map_unmapped_op_cost_estimate_key(k1, mv1), OpCostMetrics{1.0, nonnegative_int{2}}},
150+
{map_unmapped_op_cost_estimate_key(k2, mv1), OpCostMetrics{2.0, nonnegative_int{3}}},
151+
{map_unmapped_op_cost_estimate_key(k1, mv2), OpCostMetrics{1.5, nonnegative_int{1}}},
152+
{map_unmapped_op_cost_estimate_key(k2, mv2), OpCostMetrics{2.5, nonnegative_int{2}}},
153153
}},
154154
std::unordered_map<TensorSetMovement, float>{{
155155
{TensorSetMovement{{}}, 0.0},
@@ -183,13 +183,13 @@ TEST_SUITE(FF_TEST_SUITE) {
183183
cache, context, problem_tree, full_machine_spec, constraints);
184184
MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{
185185
MachineMappingForSingleLayer{
186-
OpCostMetrics{1.0, 2},
186+
OpCostMetrics{1.0, nonnegative_int{2}},
187187
ParallelLayerGuidObliviousMachineMapping{{
188188
{binary_tree_root_path(), mv1},
189189
}},
190190
},
191191
MachineMappingForSingleLayer{
192-
OpCostMetrics{1.5, 1},
192+
OpCostMetrics{1.5, nonnegative_int{1}},
193193
ParallelLayerGuidObliviousMachineMapping{{
194194
{binary_tree_root_path(), mv2},
195195
}},
@@ -214,7 +214,7 @@ TEST_SUITE(FF_TEST_SUITE) {
214214
MachineMappingForSingleLayer{
215215
OpCostMetrics{
216216
/*runtime=*/1.0 + 2.0 + 0.1,
217-
/*memory=*/2 + 3,
217+
/*memory=*/nonnegative_int{2 + 3},
218218
},
219219
ParallelLayerGuidObliviousMachineMapping{{
220220
{
@@ -232,7 +232,7 @@ TEST_SUITE(FF_TEST_SUITE) {
232232
}},
233233
},
234234
MachineMappingForSingleLayer{
235-
OpCostMetrics{1.5 + 2.5 + 0.1, 1 + 2},
235+
OpCostMetrics{1.5 + 2.5 + 0.1, nonnegative_int{1 + 2}},
236236
ParallelLayerGuidObliviousMachineMapping{{
237237
{
238238
BinaryTreePath{{
@@ -266,7 +266,7 @@ TEST_SUITE(FF_TEST_SUITE) {
266266
cache, context, problem_tree, full_machine_spec, constraints);
267267
MachineMappingWithMemoryResult correct =
268268
MachineMappingWithMemoryResult{{MachineMappingForSingleLayer{
269-
OpCostMetrics{2.5, 2},
269+
OpCostMetrics{2.5, nonnegative_int{2}},
270270
ParallelLayerGuidObliviousMachineMapping{{
271271
{
272272
BinaryTreePath{{

lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,15 @@ TEST_SUITE(FF_TEST_SUITE) {
5353

5454
OpCostMetrics cost1 = OpCostMetrics{
5555
/*runtime=*/2.0,
56-
/*memory=*/2,
56+
/*memory=*/nonnegative_int{2},
5757
};
5858
OpCostMetrics cost2 = OpCostMetrics{
5959
/*runtime=*/4.0,
60-
/*memory=*/1,
60+
/*memory=*/nonnegative_int{1},
6161
};
6262
OpCostMetrics cost3 = OpCostMetrics{
6363
/*runtime=*/2.0,
64-
/*memory=*/3,
64+
/*memory=*/nonnegative_int{3},
6565
};
6666

6767
MachineMappingForSingleLayer mm1 = MachineMappingForSingleLayer{
@@ -183,7 +183,7 @@ TEST_SUITE(FF_TEST_SUITE) {
183183

184184
OpCostMetrics pre_cost = OpCostMetrics{
185185
/*runtime=*/2.0,
186-
/*memory=*/2,
186+
/*memory=*/nonnegative_int{2},
187187
};
188188
MachineMappingWithMemoryResult pre = MachineMappingWithMemoryResult{{
189189
MachineMappingForSingleLayer{
@@ -209,7 +209,7 @@ TEST_SUITE(FF_TEST_SUITE) {
209209

210210
OpCostMetrics post_cost = OpCostMetrics{
211211
/*runtime=*/4.0,
212-
/*memory=*/1,
212+
/*memory=*/nonnegative_int{1},
213213
};
214214

215215
MachineMappingWithMemoryResult post = MachineMappingWithMemoryResult{{
@@ -378,7 +378,7 @@ TEST_SUITE(FF_TEST_SUITE) {
378378

379379
OpCostMetrics lhs_cost = OpCostMetrics{
380380
/*runtime=*/2.0,
381-
/*memory=*/2,
381+
/*memory=*/nonnegative_int{2},
382382
};
383383
MachineMappingWithMemoryResult lhs = MachineMappingWithMemoryResult{{
384384
MachineMappingForSingleLayer{
@@ -404,7 +404,7 @@ TEST_SUITE(FF_TEST_SUITE) {
404404

405405
OpCostMetrics rhs_cost = OpCostMetrics{
406406
/*runtime=*/4.0,
407-
/*memory=*/1,
407+
/*memory=*/nonnegative_int{1},
408408
};
409409
MachineMappingWithMemoryResult rhs = MachineMappingWithMemoryResult{{
410410
MachineMappingForSingleLayer{
@@ -519,15 +519,15 @@ TEST_SUITE(FF_TEST_SUITE) {
519519

520520
OpCostMetrics cost1 = OpCostMetrics{
521521
/*runtime=*/2.0,
522-
/*memory=*/2,
522+
/*memory=*/nonnegative_int{2},
523523
};
524524
OpCostMetrics cost2 = OpCostMetrics{
525525
/*runtime=*/4.0,
526-
/*memory=*/1,
526+
/*memory=*/nonnegative_int{1},
527527
};
528528
OpCostMetrics cost3 = OpCostMetrics{
529529
/*runtime=*/2.0,
530-
/*memory=*/3,
530+
/*memory=*/nonnegative_int{3},
531531
};
532532

533533
MachineMappingForSingleLayer mm1 = MachineMappingForSingleLayer{
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NONNEGATIVE_INT_H
2+
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NONNEGATIVE_INT_H
3+
4+
#include "rapidcheck.h"
5+
6+
#include <any>
7+
#include <fmt/format.h>
8+
#include <functional>
9+
#include <nlohmann/json.hpp>
10+
#include <string>
11+
12+
namespace FlexFlow {
13+
class nonnegative_int {
14+
public:
15+
nonnegative_int() = delete;
16+
explicit nonnegative_int(int value);
17+
18+
explicit operator int() const noexcept;
19+
20+
bool operator<(nonnegative_int const &other) const;
21+
bool operator==(nonnegative_int const &other) const;
22+
bool operator>(nonnegative_int const &other) const;
23+
bool operator<=(nonnegative_int const &other) const;
24+
bool operator!=(nonnegative_int const &other) const;
25+
bool operator>=(nonnegative_int const &other) const;
26+
27+
bool operator<(int const &other) const;
28+
bool operator==(int const &other) const;
29+
bool operator>(int const &other) const;
30+
bool operator<=(int const &other) const;
31+
bool operator!=(int const &other) const;
32+
bool operator>=(int const &other) const;
33+
34+
friend bool operator<(int const &lhs, nonnegative_int const &rhs);
35+
friend bool operator==(int const &lhs, nonnegative_int const &rhs);
36+
friend bool operator>(int const &lhs, nonnegative_int const &rhs);
37+
friend bool operator<=(int const &lhs, nonnegative_int const &rhs);
38+
friend bool operator!=(int const &lhs, nonnegative_int const &rhs);
39+
friend bool operator>=(int const &lhs, nonnegative_int const &rhs);
40+
41+
nonnegative_int operator+(nonnegative_int const &other) const;
42+
43+
friend std::ostream &operator<<(std::ostream &os, nonnegative_int const &n);
44+
45+
friend int format_as(nonnegative_int const &);
46+
47+
int get_value() const;
48+
49+
private:
50+
int value_;
51+
};
52+
} // namespace FlexFlow
53+
54+
namespace nlohmann {
55+
template <>
56+
struct adl_serializer<::FlexFlow::nonnegative_int> {
57+
static ::FlexFlow::nonnegative_int from_json(json const &j);
58+
static void to_json(json &j, ::FlexFlow::nonnegative_int t);
59+
};
60+
} // namespace nlohmann
61+
62+
namespace std {
63+
template <>
64+
struct hash<::FlexFlow::nonnegative_int> {
65+
std::size_t operator()(FlexFlow::nonnegative_int const &n) const noexcept;
66+
};
67+
} // namespace std
68+
69+
#endif
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#include "utils/nonnegative_int/nonnegative_int.h"
2+
3+
namespace FlexFlow {
4+
5+
nonnegative_int::nonnegative_int(int value) {
6+
if (value < 0) {
7+
throw std::invalid_argument(
8+
"Value of nonnegative_int type must be nonnegative.");
9+
}
10+
this->value_ = value;
11+
}
12+
13+
nonnegative_int::operator int() const noexcept {
14+
return this->value_;
15+
}
16+
17+
bool nonnegative_int::operator<(nonnegative_int const &other) const {
18+
return this->value_ < other.value_;
19+
}
20+
bool nonnegative_int::operator==(nonnegative_int const &other) const {
21+
return this->value_ == other.value_;
22+
}
23+
bool nonnegative_int::operator>(nonnegative_int const &other) const {
24+
return this->value_ > other.value_;
25+
}
26+
bool nonnegative_int::operator<=(nonnegative_int const &other) const {
27+
return this->value_ <= other.value_;
28+
}
29+
bool nonnegative_int::operator!=(nonnegative_int const &other) const {
30+
return this->value_ != other.value_;
31+
}
32+
bool nonnegative_int::operator>=(nonnegative_int const &other) const {
33+
return this->value_ >= other.value_;
34+
}
35+
36+
bool nonnegative_int::operator<(int const &other) const {
37+
return this->value_ < other;
38+
}
39+
bool nonnegative_int::operator==(int const &other) const {
40+
return this->value_ == other;
41+
}
42+
bool nonnegative_int::operator>(int const &other) const {
43+
return this->value_ > other;
44+
}
45+
bool nonnegative_int::operator<=(int const &other) const {
46+
return this->value_ <= other;
47+
}
48+
bool nonnegative_int::operator!=(int const &other) const {
49+
return this->value_ != other;
50+
}
51+
bool nonnegative_int::operator>=(int const &other) const {
52+
return this->value_ >= other;
53+
}
54+
55+
bool operator<(int const &lhs, nonnegative_int const &rhs) {
56+
return lhs < rhs.value_;
57+
}
58+
bool operator==(int const &lhs, nonnegative_int const &rhs) {
59+
return lhs == rhs.value_;
60+
}
61+
bool operator>(int const &lhs, nonnegative_int const &rhs) {
62+
return lhs > rhs.value_;
63+
}
64+
bool operator<=(int const &lhs, nonnegative_int const &rhs) {
65+
return lhs <= rhs.value_;
66+
}
67+
bool operator!=(int const &lhs, nonnegative_int const &rhs) {
68+
return lhs != rhs.value_;
69+
}
70+
bool operator>=(int const &lhs, nonnegative_int const &rhs) {
71+
return lhs >= rhs.value_;
72+
}
73+
74+
nonnegative_int nonnegative_int::operator+(nonnegative_int const &other) const {
75+
return nonnegative_int{this->value_ + other.value_};
76+
}
77+
78+
std::ostream &operator<<(std::ostream &os, nonnegative_int const &n) {
79+
os << n.value_;
80+
return os;
81+
}
82+
83+
int nonnegative_int::get_value() const {
84+
return this->value_;
85+
}
86+
87+
int format_as(nonnegative_int const &x) {
88+
return x.get_value();
89+
}
90+
} // namespace FlexFlow
91+
92+
namespace nlohmann {
93+
::FlexFlow::nonnegative_int
94+
adl_serializer<::FlexFlow::nonnegative_int>::from_json(json const &j) {
95+
return ::FlexFlow::nonnegative_int{j.template get<int>()};
96+
}
97+
98+
void adl_serializer<::FlexFlow::nonnegative_int>::to_json(
99+
json &j, ::FlexFlow::nonnegative_int t) {
100+
j = t.get_value();
101+
}
102+
} // namespace nlohmann
103+
104+
namespace std {
105+
std::size_t hash<::FlexFlow::nonnegative_int>::operator()(
106+
FlexFlow::nonnegative_int const &n) const noexcept {
107+
return std::hash<int>{}(n.get_value());
108+
}
109+
} // namespace std

0 commit comments

Comments
 (0)