Skip to content

Commit 3a11bc1

Browse files
fixed bug in settting up constraints for unequal model sizes
1 parent 68f227c commit 3a11bc1

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

batchglm/train/tf/base_glm/estimator_graph.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -649,10 +649,12 @@ def __init__(
649649

650650
self.constraints_loc = self._set_constraints(
651651
constraints=constraints_loc,
652+
num_design_params=self.num_design_loc_params,
652653
dtype=dtype
653654
)
654655
self.constraints_scale = self._set_constraints(
655656
constraints=constraints_scale,
657+
num_design_params=self.num_design_scale_params,
656658
dtype=dtype
657659
)
658660

@@ -724,15 +726,16 @@ def _set_out_var(
724726
def _set_constraints(
725727
self,
726728
constraints,
729+
num_design_params,
727730
dtype
728731
):
729732
if constraints is None:
730733
return tf.eye(
731-
num_rows=tf.constant(self.num_design_loc_params, shape=(), dtype="int32"),
734+
num_rows=tf.constant(num_design_params, shape=(), dtype="int32"),
732735
dtype=dtype
733736
)
734737
else:
735-
assert constraints.shape[0] == self.num_design_loc_params, "constraint dimension mismatch"
738+
assert constraints.shape[0] == num_design_params, "constraint dimension mismatch"
736739
return tf.cast(constraints, dtype=dtype)
737740

738741
@abc.abstractmethod

0 commit comments

Comments
 (0)