-
Notifications
You must be signed in to change notification settings - Fork 190
Expand file tree
/
Copy pathez_tree.pyx
More file actions
158 lines (126 loc) · 7 KB
/
ez_tree.pyx
File metadata and controls
158 lines (126 loc) · 7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# distutils:language=c++
# cython:language_level=3
import cython
from libcpp.vector cimport vector
cdef class MinMaxStatsList:
    """Python-side owner of a heap-allocated ``CMinMaxStatsList``.

    The C++ object is created in ``__cinit__`` and destroyed in
    ``__dealloc__``.  Other wrappers in this module read the raw pointer
    ``cmin_max_stats_lst`` and hand it to the C++ search routines.
    """

    @cython.binding
    def __cinit__(self, int num):
        # Construct the wrapped C++ object from ``num``.
        self.cmin_max_stats_lst = new CMinMaxStatsList(num)

    @cython.binding
    def set_delta(self, float value_delta_max):
        """Forward ``value_delta_max`` to ``CMinMaxStatsList.set_delta``."""
        # Cython resolves ``.`` on a pointer as member access, so this is the
        # same call as dereferencing with ``[0]`` first.
        self.cmin_max_stats_lst.set_delta(value_delta_max)

    def __dealloc__(self):
        # Release the object allocated with ``new`` in __cinit__.
        del self.cmin_max_stats_lst
cdef class ResultsWrapper:
    """Holds a ``CSearchResults`` value in the ``cresults`` attribute.

    The module-level search functions pass ``cresults`` to the C++ routines
    and read their output fields back from it.
    """

    @cython.binding
    def __cinit__(self, int num):
        # Build the wrapped results object from ``num``.
        self.cresults = CSearchResults(num)

    @cython.binding
    def get_search_len(self):
        """Return the ``search_lens`` field of the wrapped results."""
        search_lens = self.cresults.search_lens
        return search_lens
cdef class Roots:
    """Python-side owner of a heap-allocated ``CRoots`` object.

    The C++ object is created in ``__cinit__`` and destroyed in
    ``__dealloc__``.  Most methods forward straight to the C++ method of the
    same name and return whatever it yields.
    """

    @cython.binding
    def __cinit__(self, int root_num, vector[vector[int]] legal_actions_list):
        self.root_num = root_num
        self.roots = new CRoots(root_num, legal_actions_list)
        # Store legal_actions for access from Python
        self.legal_actions_list = legal_actions_list

    @cython.binding
    def prepare(self, float root_noise_weight, list noises, list value_prefix_pool,
                list policy_logits_pool, vector[int] & to_play_batch):
        """Forward all arguments to ``CRoots.prepare``."""
        self.roots[0].prepare(root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play_batch)

    @cython.binding
    def prepare_no_noise(self, list value_prefix_pool, list policy_logits_pool, vector[int] & to_play_batch):
        """Forward all arguments to ``CRoots.prepare_no_noise``."""
        self.roots[0].prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play_batch)

    @cython.binding
    def get_trajectories(self):
        """Return the result of ``CRoots.get_trajectories``."""
        return self.roots[0].get_trajectories()

    @cython.binding
    def get_distributions(self):
        """Return the result of ``CRoots.get_distributions``."""
        return self.roots[0].get_distributions()

    @cython.binding
    def get_values(self):
        """Return the result of ``CRoots.get_values``."""
        return self.roots[0].get_values()

    @cython.binding
    def get_root_policies(self, MinMaxStatsList min_max_stats_lst not None):
        """Return the result of ``CRoots.get_root_policies``.

        ``min_max_stats_lst`` must be a real ``MinMaxStatsList``.  Without
        ``not None`` Cython would accept ``None`` and the attribute read
        below would touch invalid memory; now a ``TypeError`` is raised.
        """
        return self.roots[0].get_root_policies(min_max_stats_lst.cmin_max_stats_lst)

    @cython.binding
    def get_best_actions(self):
        """Return the result of ``CRoots.get_best_actions``."""
        return self.roots[0].get_best_actions()

    # visualize related code
    #def get_root(self, int index):
    #    return self.roots[index]

    @cython.binding
    def clear(self):
        """Call ``CRoots.clear`` on the wrapped object."""
        self.roots[0].clear()

    @cython.binding
    def get_legal_actions(self):
        """Return the stored legal-actions list as a new Python ``list``."""
        return list(self.legal_actions_list)

    def __dealloc__(self):
        # Release the object allocated with ``new`` in __cinit__.
        del self.roots

    @property
    def num(self):
        """The ``root_num`` value given at construction."""
        return self.root_num
cdef class Node:
    # Wrapper exposing ``expand`` on ``self.cnode``.  The ``cnode`` attribute
    # is declared outside the visible code (presumably in the matching .pxd;
    # TODO confirm) and neither constructor below initializes it.

    # NOTE(review): ``__cinit__`` is defined twice with different signatures.
    # Only one definition can be bound to the type; presumably the later one
    # wins, making this zero-argument version dead code -- confirm against the
    # generated C before removing either.
    def __cinit__(self):
        pass

    # Accepts ``prior`` and ``legal_actions`` but does nothing with them.
    def __cinit__(self, float prior, vector[int] & legal_actions):
        pass

    @cython.binding
    def expand(self, int to_play, int current_latent_state_index, int batch_index, float value_prefix,
               list policy_logits):
        """Convert ``policy_logits`` to a C++ float vector and call ``cnode.expand``.

        The remaining arguments are forwarded unchanged.  Returns ``None``.
        """
        # Assigning a Python list to a typed vector performs the conversion.
        cdef vector[float] cpolicy = policy_logits
        self.cnode.expand(to_play, current_latent_state_index, batch_index, value_prefix, cpolicy)
@cython.binding
def batch_backpropagate(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies,
                        MinMaxStatsList min_max_stats_lst not None, ResultsWrapper results not None, list is_reset_list,
                        list to_play_batch):
    """Convert the list arguments to C++ vectors and call ``cbatch_backpropagate``.

    ``value_prefixs`` and ``values`` become ``vector[float]`` and ``policies``
    becomes ``vector[vector[float]]``; the other arguments are forwarded as
    given.  Returns ``None``.

    ``min_max_stats_lst`` and ``results`` are declared ``not None`` because
    their C-level fields are read directly; passing ``None`` now raises
    ``TypeError`` instead of reading invalid memory.
    """
    cdef vector[float] cvalue_prefixs = value_prefixs
    cdef vector[float] cvalues = values
    cdef vector[vector[float]] cpolicies = policies

    cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies,
                         min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch)
@cython.binding
def batch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies,
                                   MinMaxStatsList min_max_stats_lst not None, ResultsWrapper results not None, list is_reset_list,
                                   list to_play_batch, list no_inference_lst, list reuse_lst, list reuse_value_lst):
    """Convert the list arguments to C++ vectors and call ``cbatch_backpropagate_with_reuse``.

    ``value_prefixs``, ``values`` and ``reuse_value_lst`` become
    ``vector[float]`` and ``policies`` becomes ``vector[vector[float]]``; the
    other arguments are forwarded as given.  Returns ``None``.

    ``min_max_stats_lst`` and ``results`` are declared ``not None`` because
    their C-level fields are read directly; passing ``None`` now raises
    ``TypeError`` instead of reading invalid memory.
    """
    cdef vector[float] cvalue_prefixs = value_prefixs
    cdef vector[float] cvalues = values
    cdef vector[vector[float]] cpolicies = policies
    cdef vector[float] creuse_value_lst = reuse_value_lst

    cbatch_backpropagate_with_reuse(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies,
                                    min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch, no_inference_lst, reuse_lst, creuse_value_lst)
# ========== MuZero/UCB-style traversal (backup version) ==========
@cython.binding
def batch_traverse_ucb(Roots roots not None, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst not None,
                       ResultsWrapper results not None, list virtual_to_play_batch):
    """MuZero/UCB-style batched tree traversal (backup version).

    Calls ``cbatch_traverse_ucb`` with the wrapped C++ objects, then returns a
    4-tuple read from ``results.cresults``:
    ``(latent_state_index_in_search_path, latent_state_index_in_batch,
    last_actions, virtual_to_play_batchs)``.

    The extension-type arguments are declared ``not None`` because their
    C-level fields are read directly; passing ``None`` now raises
    ``TypeError`` instead of reading invalid memory.
    """
    cbatch_traverse_ucb(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst,
                        results.cresults, virtual_to_play_batch)

    return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs
# ========== EfficientZero V2-style traversal (Sequential Halving integrated) ==========
@cython.binding
def batch_traverse(Roots roots not None, MinMaxStatsList min_max_stats_lst not None, ResultsWrapper results not None,
                   int num_simulations, int simulation_idx, list gumbel_noises,
                   int current_num_top_actions, list virtual_to_play_batch):
    """EfficientZero V2-style batched tree traversal with Sequential Halving logic.

    Converts ``gumbel_noises`` to ``vector[vector[float]]``, calls
    ``cbatch_traverse``, then returns a 4-tuple read from ``results.cresults``:
    ``(latent_state_index_in_search_path, latent_state_index_in_batch,
    last_actions, virtual_to_play_batchs)``.

    The extension-type arguments are declared ``not None`` because their
    C-level fields are read directly; passing ``None`` now raises
    ``TypeError`` instead of reading invalid memory.
    """
    cdef vector[vector[float]] c_gumbel_noises = gumbel_noises

    cbatch_traverse(roots.roots, min_max_stats_lst.cmin_max_stats_lst, results.cresults,
                    num_simulations, simulation_idx, c_gumbel_noises,
                    current_num_top_actions, virtual_to_play_batch)

    return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs
@cython.binding
def batch_traverse_with_reuse(Roots roots not None, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst not None,
                              ResultsWrapper results not None, list virtual_to_play_batch, list true_action, list reuse_value):
    """Call ``cbatch_traverse_with_reuse`` and return the traversal outputs.

    All arguments are forwarded (extension types via their wrapped C++
    objects).  Returns a 4-tuple read from ``results.cresults``:
    ``(latent_state_index_in_search_path, latent_state_index_in_batch,
    last_actions, virtual_to_play_batchs)``.

    The extension-type arguments are declared ``not None`` because their
    C-level fields are read directly; passing ``None`` now raises
    ``TypeError`` instead of reading invalid memory.
    """
    cbatch_traverse_with_reuse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults,
                               virtual_to_play_batch, true_action, reuse_value)

    return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs
@cython.binding
def batch_sequential_halving(Roots roots not None, list gumbel_noises, MinMaxStatsList min_max_stats_lst not None,
                             int current_phase, int current_num_top_actions):
    """Convert ``gumbel_noises`` and return the result of ``c_batch_sequential_halving``.

    ``gumbel_noises`` becomes ``vector[vector[float]]``; the remaining
    arguments are forwarded (extension types via their wrapped C++ pointers).

    ``roots`` and ``min_max_stats_lst`` are declared ``not None`` because
    their C-level fields are read directly; passing ``None`` now raises
    ``TypeError`` instead of reading invalid memory.
    """
    cdef vector[vector[float]] c_gumbel_noises = gumbel_noises

    return c_batch_sequential_halving(roots.roots, c_gumbel_noises, min_max_stats_lst.cmin_max_stats_lst,
                                      current_phase, current_num_top_actions)