@@ -225,6 +225,8 @@ class ModelVars:
225225 a_var : tf .Variable
226226 b_var : tf .Variable
227227 params : tf .Variable
228+ converged : np .ndarray
229+
228230 """ Build tf.Variables to be optimzed and their constraints.
229231
230232 a_var and b_var slices of the tf.Variable params which contains
@@ -309,8 +311,13 @@ def __init__(
309311 axis = 0
310312 ), name = "params" )
311313
312- a_var = params [0 :init_a .shape [0 ]]
313- b_var = params [init_a .shape [0 ]:]
314+ params_by_gene = [tf .expand_dims (params [:, i ], axis = - 1 ) for i in range (params .shape [1 ])]
315+ a_by_gene = [x [0 :init_a .shape [0 ],:] for x in params_by_gene ]
316+ b_by_gene = [x [init_a .shape [0 ]:, :] for x in params_by_gene ]
317+ a_var = tf .concat (a_by_gene , axis = 1 )
318+ b_var = tf .concat (b_by_gene , axis = 1 )
319+ #a_var = params[0:init_a.shape[0]]
320+ #b_var = params[init_a.shape[0]:]
314321
315322 # Define first layer of computation graph on identifiable variables
316323 # to yield dependent set of parameters of model for each location
@@ -334,3 +341,10 @@ def __init__(
334341 self .a_var = a_var
335342 self .b_var = b_var
336343 self .params = params
344+ # Properties to follow gene-wise convergence.
345+ self .params_by_gene = params_by_gene
346+ self .a_by_gene = a_by_gene
347+ self .b_by_gene = b_by_gene
348+ self .converged = np .repeat (a = False , repeats = self .params .shape [1 ]) # Initialise to non-converged.
349+ self .n_features = self .params .shape [1 ]
350+
0 commit comments