@@ -77,9 +77,15 @@ def np_clip_param(param, name):
7777 )
7878
7979
80- def apply_constraints (constraints : np .ndarray , var : tf .Variable , dtype : str ):
80+ def apply_constraints (
81+ constraints : np .ndarray ,
82+ dtype : str ,
83+ var_all : tf .Variable = None ,
84+ var_indep : tf .Tensor = None
85+ ):
8186 """ Iteratively build depend variables from other variables via constraints
8287
88+ :type var_all: object
8389 :param constraints: Array with constraints in rows and model parameters in columns.
8490 Each constraint contains non-zero entries for the a of parameters that
8591 has to sum to zero. This constraint is enforced by binding one parameter
@@ -88,16 +94,21 @@ def apply_constraints(constraints: np.ndarray, var: tf.Variable, dtype: str):
8894 parameter is indicated by a -1 in this array, the independent parameters
8995 of that constraint (which may be dependent at an earlier constraint)
9096 are indicated by a 1.
91- :param var: Variable tensor features x independent parameters.
97+ :param var_all: Variable tensor features x independent parameters.
98+ All model parameters.
99+ :param var_all: Variable tensor features x independent parameters.
100+ Only independent model parameters, ie. not parameters defined by constraints.
92101 :param dtype: Precision used in tensorflow.
93102
94103 :return: Full model parameter matrix with dependent parameters.
95104 """
96105
97106 # Find all independent variables:
98- idx_indep = np .where (np .all (constraints == - 1 , axis = 0 ))[0 ]
107+ idx_indep = np .where (np .all (constraints != - 1 , axis = 0 ))[0 ]
108+ idx_indep .astype (dtype = np .int64 )
99109 # Relate constraints to dependent variables:
100110 idx_dep = np .array ([np .where (constr == - 1 )[0 ] for constr in constraints ])
111+ idx_dep .astype (dtype = np .int64 )
101112 # Only choose dependent variable which was not already defined above:
102113 idx_dep = np .concatenate ([
103114 x [[xx not in np .concatenate (idx_dep [:i ]) for xx in x ]] if i > 0 else x
@@ -109,7 +120,13 @@ def apply_constraints(constraints: np.ndarray, var: tf.Variable, dtype: str):
109120 # tensor is initialised with the independent variables var
110121 # and is grown by one varibale in each iteration until
111122 # all variables are there.
112- x = var
123+ if var_all is None :
124+ x = var_indep
125+ elif var_indep is None :
126+ x = tf .gather (params = var_all , indices = idx_indep , axis = 0 )
127+ else :
128+ raise ValueError ("only give var_all or var_indep to apply_constraints." )
129+
113130 for i in range (constraints .shape [0 ]):
114131 idx_var_i = np .concatenate ([idx_indep , idx_dep [:i ]])
115132 constraint_model = constraints [[i ], :][:, idx_var_i ]
@@ -150,12 +167,11 @@ def __init__(
150167 # Define first layer of computation graph on identifiable variables
151168 # to yield dependent set of parameters of model for each location
152169 # and scale model.
153-
154170 if constraints_loc is not None :
155- a = apply_constraints (constraints_loc , a , dtype = dtype )
171+ a = apply_constraints (constraints = constraints_loc , var_all = a , dtype = dtype )
156172
157173 if constraints_scale is not None :
158- b = apply_constraints (constraints_scale , b , dtype = dtype )
174+ b = apply_constraints (constraints = constraints_scale , var_all = b , dtype = dtype )
159175
160176 with tf .name_scope ("mu" ):
161177 log_mu = tf .matmul (design_loc , a , name = "log_mu_obs" )
@@ -299,14 +315,13 @@ def __init__(
299315 # Define first layer of computation graph on identifiable variables
300316 # to yield dependent set of parameters of model for each location
301317 # and scale model.
302-
303318 if constraints_loc is not None :
304- a = apply_constraints (constraints_loc , a_var , dtype = dtype )
319+ a = apply_constraints (constraints = constraints_loc , var_indep = a_var , dtype = dtype )
305320 else :
306321 a = a_var
307322
308323 if constraints_scale is not None :
309- b = apply_constraints (constraints_scale , b_var , dtype = dtype )
324+ b = apply_constraints (constraints = constraints_scale , var_indep = b_var , dtype = dtype )
310325 else :
311326 b = b_var
312327
0 commit comments