@@ -483,30 +483,39 @@ def build_equality_constraints_string(
483483 E.g. ["batch1 + batch2 + batch3 = 0"]
484484 :return: a constraint matrix
485485 """
486- # TODO: automatically generate string constraints from factors
486+ n_par_all = dmat .data_vars ['design' ].values .shape [1 ]
487+ n_par_free = n_par_all - len (constraints )
488+
487489 di = patsy .DesignInfo (dmat .coords ["design_params" ].values )
488490 constraint_ls = [di .linear_constraint (x ).coefs [0 ] for x in constraints ]
489- idx_constrained = [np .where (x == 1 )[0 ][0 ] for x in constraint_ls ]
490- idx_unconstr = list (
491- set (list (range ( dmat . data_vars [ "design" ]. shape [ 1 ]))) -
492- set (list ( idx_constrained ) )
493- )
491+ idx_constr = np . asarray ( [np .where (x == 1 )[0 ][0 ] for x in constraint_ls ])
492+ idx_depending = [ np . where ( x == 1 )[ 0 ][ 1 :] for x in constraint_ls ]
493+ idx_unconstr = np . asarray (list (
494+ set (np . asarray ( range ( n_par_all ))) - set ( idx_constr )
495+ ))
494496
495497 dmat_var = xr .DataArray (
496498 dims = [dmat .data_vars ['design' ].dims [0 ], "params" ],
497499 data = dmat .data_vars ["design" ][:,idx_unconstr ],
498500 coords = {dmat .data_vars ['design' ].dims [0 ]: dmat .coords ["observations" ].values ,
499501 "params" : dmat .coords ["design_params" ].values [idx_unconstr ]}
500502 )
501- constraint_mat = np .vstack (constraint_ls )[:,idx_unconstr ]
502503
503- constraints = np .vstack ([
504- np .identity (n = len (idx_unconstr )),
505- - constraint_mat
506- ])
504+ constraint_mat = np .zeros ([n_par_all , n_par_free ])
505+ for i in range (n_par_all ):
506+ if i in idx_constr :
507+ idx_dep_i = idx_depending [np .where (idx_constr == i )[0 ][0 ]]
508+ idx_dep_i = np .asarray ([np .where (idx_unconstr == x )[0 ] for x in idx_dep_i ])
509+ constraint_mat [i , :] = 0
510+ constraint_mat [i , idx_dep_i ] = - 1
511+ else :
512+ idx_unconstr_i = np .where (idx_unconstr == i )
513+ constraint_mat [i , :] = 0
514+ constraint_mat [i , idx_unconstr_i ] = 1
515+
507516 constraints_ar = parse_constraints (
508517 dmat = dmat ,
509- constraints = constraints ,
518+ constraints = constraint_mat ,
510519 dims = dims
511520 )
512521
0 commit comments