@@ -36,6 +36,7 @@ void Policy::get_policy(const std::vector<double> &q_values, const std::vector<d
3636 throw std::invalid_argument (" No valid q-values" );
3737 }
3838 if (valid_count == 1 ) {
39+ std::fill (result.begin (), result.end (), MIN_FLOAT);
3940 result[valid_indices[0 ]] = 1 ;
4041 return ;
4142 }
@@ -52,7 +53,11 @@ void Policy::get_policy(const std::vector<double> &q_values, const std::vector<d
5253 bool is_nan = std::any_of (result.begin (), result.end (), [](double d) { return std::isnan (d); });
5354 assert (!is_nan);
5455 assert (q_values.size () == result.size ());
55- double sum = std::accumulate (result.begin (), result.end (), 0.0 );
56+ double sum = 0.0 ;
57+ for (size_t i = 0 ; i < result.size (); i++) {
58+ if (result[i] > MIN_FLOAT)
59+ sum += result[i];
60+ }
5661 assert (sum > 0.99 && sum < 1.01 );
5762}
5863
@@ -68,10 +73,10 @@ void Policy::alpha_zero(const std::vector<double> &q_values, const std::vector<d
6873 if (pi_values[i] > MIN_FLOAT && q_values[i] > MIN_FLOAT) {
6974 scores[i] = q_values[i] + exploration * pi_values[i] * std::sqrt (count_sum) / (1 + counts_d[i]);
7075 valid_count++;
76+ score_sum += scores[i];
7177 } else {
72- scores[i] = 0 ;
78+ scores[i] = MIN_FLOAT ;
7379 }
74- score_sum += scores[i];
7580 }
7681 assert (valid_count > 0 );
7782 // If score sum is 0, simply return the uniform distribution over valid actions
@@ -80,14 +85,17 @@ void Policy::alpha_zero(const std::vector<double> &q_values, const std::vector<d
8085 if (q_values[i] > MIN_FLOAT && pi_values[i] > MIN_FLOAT) {
8186 result[i] = 1.0 / static_cast <double >(valid_count);
8287 } else {
83- result[i] = 0 ;
88+ result[i] = MIN_FLOAT ;
8489 }
8590 }
8691 return ;
8792 }
8893 // Normalize the scores
8994 for (size_t i = 0 ; i < q_values.size (); i++) {
90- result[i] = scores[i] / score_sum;
95+ if (scores[i] > MIN_FLOAT)
96+ result[i] = scores[i] / score_sum;
97+ else
98+ result[i] = MIN_FLOAT;
9199 }
92100}
93101
@@ -108,7 +116,7 @@ double Policy::find_rpo_alpha(double alpha_min, double alpha_max, const std::vec
108116 }
109117 pi_difference_sum += scaled_pi_values[i] / diff;
110118 }
111- if ((pi_difference_sum - 1 ) < TOLERANCE) {
119+ if (std::abs (pi_difference_sum - 1 ) < TOLERANCE) {
112120 return alpha_mid;
113121 }
114122 if (pi_difference_sum > 1 )
@@ -133,7 +141,7 @@ void Policy::mcts_rpo(const std::vector<double> &q_values, const std::vector<dou
133141 q_sum += q_values[i];
134142 valid_count++;
135143 } else {
136- result[i] = 0 ;
144+ result[i] = MIN_FLOAT ;
137145 }
138146 }
139147 // If q sum is 0, simply return the uniform distribution over valid actions
@@ -143,21 +151,25 @@ void Policy::mcts_rpo(const std::vector<double> &q_values, const std::vector<dou
143151 if (q_values[i] > MIN_FLOAT && pi_values[i] > MIN_FLOAT) {
144152 result[i] = 1.0 / static_cast <double >(valid_count);
145153 } else {
146- result[i] = 0 ;
154+ result[i] = MIN_FLOAT ;
147155 }
148156 }
149157 return ;
150158 }
151159
152160 for (size_t i = 0 ; i < q_values.size (); i++) {
153- result[i] /= q_sum;
161+ if (q_values[i] > MIN_FLOAT) {
162+ result[i] = result[i] / q_sum;
163+ } else {
164+ result[i] = MIN_FLOAT;
165+ }
154166 }
155167 return ;
156168 }
157169
158170 std::vector<double > scaled_pi_values (q_values.size ());
159171
160- double alpha_min, alpha_max = 0 ;
172+ double alpha_min = 0 , alpha_max = 0 ;
161173
162174 for (size_t i = 0 ; i < q_values.size (); i++) {
163175 scaled_pi_values[i] = pi_values[i] * multiplier;
@@ -167,11 +179,19 @@ void Policy::mcts_rpo(const std::vector<double> &q_values, const std::vector<dou
167179 double alpha = find_rpo_alpha (alpha_min, alpha_max, q_values, scaled_pi_values);
168180 double result_sum = 0 ;
169181 for (size_t i = 0 ; i < q_values.size (); i++) {
170- result[i] = scaled_pi_values[i] / std::max ((alpha - q_values[i]), EPSILON);
171- result_sum += result[i];
182+ if (q_values[i] > MIN_FLOAT) {
183+ result[i] = scaled_pi_values[i] / std::max ((alpha - q_values[i]), EPSILON);
184+ result_sum += result[i];
185+ }
186+ else {
187+ result[i] = MIN_FLOAT;
188+ }
172189 }
173190 for (size_t i = 0 ; i < q_values.size (); i++) {
174- result[i] /= result_sum;
191+ if (result[i] > MIN_FLOAT)
192+ result[i] = result[i] / result_sum;
193+ else
194+ result[i] = MIN_FLOAT;
175195 }
176196}
177197
0 commit comments