|
| 1 | +import ps_pb2 as pslib |
| 2 | + |
| 3 | +class Server(object): |
| 4 | + def __init__(self): |
| 5 | + pass |
| 6 | + |
| 7 | + |
| 8 | +class Worker(object): |
| 9 | + def __init__(self): |
| 10 | + pass |
| 11 | + |
| 12 | + |
| 13 | +class DownpourServer(Server): |
| 14 | + def __init__(self): |
| 15 | + #self.server_ = pslib.ServerParameter().downpour_server_param |
| 16 | + self.server_ = pslib.ServerParameter() |
| 17 | + |
| 18 | + def add_sparse_table(self, table_id, learning_rate, |
| 19 | + slot_key, slot_value_var, slot_grad_var): |
| 20 | + #table = self.server_.downpour_table_param.add() |
| 21 | + table = self.server_.downpour_server_param.downpour_table_param.add() |
| 22 | + table.table_id = table_id |
| 23 | + table.type = PS_SPARSE_TABLE |
| 24 | + table.accessor.accessor_class = "DownpourFeatureValueAccessor" |
| 25 | + table.accessor.dense_sgd_param.adam.learning_rate = learning_rate |
| 26 | + table.accessor.fea_dim = slot_value_var[0].shape[1] |
| 27 | + |
| 28 | + def add_dense_table(self, table_id, learning_rate, |
| 29 | + param_var, grad_var): |
| 30 | + #table = self.server_.downpour_table_param.add() |
| 31 | + table = self.server_.downpour_server_param.downpour_table_param.add() |
| 32 | + table.table_id = table_id |
| 33 | + table.type = PS_DENSE_TABLE |
| 34 | + table.accessor.accessor_class = "DownpourDenseValueAccessor" |
| 35 | + table.accessor.sparse_sgd_param.learning_rate = learning_rate |
| 36 | + table.accessor.fea_dim = 1 |
| 37 | + #table.accessor.fea_dim = reduce(lambda x, y: x.shape, 1 for x in param_var) |
| 38 | + |
| 39 | + def get_desc(self): |
| 40 | + return self.server_ |
| 41 | + |
| 42 | + |
| 43 | +class DownpourWorker(Worker): |
| 44 | + def __init__(self, window): |
| 45 | + self.window = window |
| 46 | + #self.worker_ = pslib.WorkerParameter().downpour_worker_param |
| 47 | + #self.worker_ = pslib.WorkerParameter() |
| 48 | + self.worker_ = pslib.DownpourTrainerParameter() |
| 49 | + #self.worker_.pull_dense_per_batch = window |
| 50 | + #self.worker_.push_dense_per_batch = window |
| 51 | + #self.worker_.downpour_worker_param.pull_dense_per_batch = window |
| 52 | + #self.worker_.downpour_worker_param.push_dense_per_batch = window |
| 53 | + self.worker_.pull_dense_per_batch = window |
| 54 | + self.worker_.push_dense_per_batch = window |
| 55 | + print(self.worker_) |
| 56 | + |
| 57 | + def add_sparse_table(self, table_id, |
| 58 | + slot_keys, slot_value_vars, slot_grad_vars): |
| 59 | + #table = self.worker_.sparse_table.add() |
| 60 | + table = self.worker_.downpour_worker_param.sparse_table.add() |
| 61 | + table.table_id = table_id |
| 62 | + table.slot.extend(slot_keys) |
| 63 | + self.worker_.extend([grad.name for grad in slot_grad_vars]) |
| 64 | + |
| 65 | + def add_dense_table(self, table_id, param_vars, grad_vars): |
| 66 | + #table = self.worker_.dense_table.add() |
| 67 | + table = self.worker_.downpour_worker_param.dense_table.add() |
| 68 | + table.table_id = table_id |
| 69 | + table.dense_variable_name.extend([p.name for p in param_vars]) |
| 70 | + table.dense_gradient_variable_name.extend([g.name for g in grad_vars]) |
| 71 | + |
| 72 | + def get_desc(self): |
| 73 | + return self.worker_ |
0 commit comments