Skip to content

Commit 4474cff

Browse files
authored
Fix policies minfloat (#22)
* Add policy minfloat * Set alpha min = 0 * Add minfloat handling for RPO * Fix pi difference * Fix logarithmic policy
1 parent ac4a980 commit 4474cff

File tree

2 files changed

+40
-13
lines changed

2 files changed

+40
-13
lines changed

src/graph/htps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,13 @@ Simulation HTPS::find_leaves_to_expand(std::vector<std::shared_ptr<theorem>> &te
891891
tactic_id = std::distance(node_policy.begin(), std::max_element(node_policy.begin(), node_policy.end()));
892892
} else {
893893
// Normal softmax with temperature, i.e. exp(p / temperature)
894+
// But take logarithm of policy first, as done in evariste
895+
for (size_t i = 0; i < node_policy.size(); i++) {
896+
if (node_policy[i] > MIN_FLOAT)
897+
node_policy[i] = std::log(node_policy[i]);
898+
else
899+
node_policy[i] = MIN_FLOAT;
900+
}
894901
double p_sum = 0;
895902
for (auto &p: node_policy) {
896903
p = std::exp(p / params.policy_temperature);

src/model/policy.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void Policy::get_policy(const std::vector<double> &q_values, const std::vector<d
3636
throw std::invalid_argument("No valid q-values");
3737
}
3838
if (valid_count == 1) {
39+
std::fill(result.begin(), result.end(), MIN_FLOAT);
3940
result[valid_indices[0]] = 1;
4041
return;
4142
}
@@ -52,7 +53,11 @@ void Policy::get_policy(const std::vector<double> &q_values, const std::vector<d
5253
bool is_nan = std::any_of(result.begin(), result.end(), [](double d) { return std::isnan(d); });
5354
assert (!is_nan);
5455
assert (q_values.size() == result.size());
55-
double sum = std::accumulate(result.begin(), result.end(), 0.0);
56+
double sum = 0.0;
57+
for (size_t i = 0; i < result.size(); i++) {
58+
if (result[i] > MIN_FLOAT)
59+
sum += result[i];
60+
}
5661
assert (sum > 0.99 && sum < 1.01);
5762
}
5863

@@ -68,10 +73,10 @@ void Policy::alpha_zero(const std::vector<double> &q_values, const std::vector<d
6873
if (pi_values[i] > MIN_FLOAT && q_values[i] > MIN_FLOAT) {
6974
scores[i] = q_values[i] + exploration * pi_values[i] * std::sqrt(count_sum) / (1 + counts_d[i]);
7075
valid_count++;
76+
score_sum += scores[i];
7177
} else {
72-
scores[i] = 0;
78+
scores[i] = MIN_FLOAT;
7379
}
74-
score_sum += scores[i];
7580
}
7681
assert (valid_count > 0);
7782
// If score sum is 0, simply return the uniform distribution over valid actions
@@ -80,14 +85,17 @@ void Policy::alpha_zero(const std::vector<double> &q_values, const std::vector<d
8085
if (q_values[i] > MIN_FLOAT && pi_values[i] > MIN_FLOAT) {
8186
result[i] = 1.0 / static_cast<double>(valid_count);
8287
} else {
83-
result[i] = 0;
88+
result[i] = MIN_FLOAT;
8489
}
8590
}
8691
return;
8792
}
8893
// Normalize the scores
8994
for (size_t i = 0; i < q_values.size(); i++) {
90-
result[i] = scores[i] / score_sum;
95+
if (scores[i] > MIN_FLOAT)
96+
result[i] = scores[i] / score_sum;
97+
else
98+
result[i] = MIN_FLOAT;
9199
}
92100
}
93101

@@ -108,7 +116,7 @@ double Policy::find_rpo_alpha(double alpha_min, double alpha_max, const std::vec
108116
}
109117
pi_difference_sum += scaled_pi_values[i] / diff;
110118
}
111-
if ((pi_difference_sum - 1) < TOLERANCE) {
119+
if (std::abs(pi_difference_sum - 1) < TOLERANCE) {
112120
return alpha_mid;
113121
}
114122
if (pi_difference_sum > 1)
@@ -133,7 +141,7 @@ void Policy::mcts_rpo(const std::vector<double> &q_values, const std::vector<dou
133141
q_sum += q_values[i];
134142
valid_count++;
135143
} else {
136-
result[i] = 0;
144+
result[i] = MIN_FLOAT;
137145
}
138146
}
139147
// If q sum is 0, simply return the uniform distribution over valid actions
@@ -143,21 +151,25 @@ void Policy::mcts_rpo(const std::vector<double> &q_values, const std::vector<dou
143151
if (q_values[i] > MIN_FLOAT && pi_values[i] > MIN_FLOAT) {
144152
result[i] = 1.0 / static_cast<double>(valid_count);
145153
} else {
146-
result[i] = 0;
154+
result[i] = MIN_FLOAT;
147155
}
148156
}
149157
return;
150158
}
151159

152160
for (size_t i = 0; i < q_values.size(); i++) {
153-
result[i] /= q_sum;
161+
if (q_values[i] > MIN_FLOAT) {
162+
result[i] = result[i] / q_sum;
163+
} else {
164+
result[i] = MIN_FLOAT;
165+
}
154166
}
155167
return;
156168
}
157169

158170
std::vector<double> scaled_pi_values(q_values.size());
159171

160-
double alpha_min, alpha_max = 0;
172+
double alpha_min = 0, alpha_max = 0;
161173

162174
for (size_t i = 0; i < q_values.size(); i++) {
163175
scaled_pi_values[i] = pi_values[i] * multiplier;
@@ -167,11 +179,19 @@ void Policy::mcts_rpo(const std::vector<double> &q_values, const std::vector<dou
167179
double alpha = find_rpo_alpha(alpha_min, alpha_max, q_values, scaled_pi_values);
168180
double result_sum = 0;
169181
for (size_t i = 0; i < q_values.size(); i++) {
170-
result[i] = scaled_pi_values[i] / std::max((alpha - q_values[i]), EPSILON);
171-
result_sum += result[i];
182+
if (q_values[i] > MIN_FLOAT) {
183+
result[i] = scaled_pi_values[i] / std::max((alpha - q_values[i]), EPSILON);
184+
result_sum += result[i];
185+
}
186+
else {
187+
result[i] = MIN_FLOAT;
188+
}
172189
}
173190
for (size_t i = 0; i < q_values.size(); i++) {
174-
result[i] /= result_sum;
191+
if (result[i] > MIN_FLOAT)
192+
result[i] = result[i] / result_sum;
193+
else
194+
result[i] = MIN_FLOAT;
175195
}
176196
}
177197

0 commit comments

Comments
 (0)