Skip to content

Commit 86f4e5d

Browse files
fixed param object shapes in jacobians and hessians
1 parent 5e7a03b commit 86f4e5d

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

batchglm/train/tf/nb_glm/hessians.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,8 @@ def _red(prev, cur):
502502
return tf.add(prev, cur)
503503

504504
params = model_vars.params
505-
p_shape_a = model_vars.a.shape[0]
506-
p_shape_b = model_vars.b.shape[0]
505+
p_shape_a = model_vars.a_var.shape[0] # This has to be _var to work with constraints.
506+
p_shape_b = model_vars.b_var.shape[0] # This has to be _var to work with constraints.
507507

508508
if iterator:
509509
H = op_utils.map_reduce(
@@ -671,8 +671,8 @@ def _red(prev, cur):
671671
return [tf.add(p, c) for p, c in zip(prev, cur)]
672672

673673
params = model_vars.params
674-
p_shape_a = model_vars.a.shape[0]
675-
p_shape_b = model_vars.b.shape[0]
674+
p_shape_a = model_vars.a_var.shape[0] # This has to be _var to work with constraints.
675+
p_shape_b = model_vars.b_var.shape[0] # This has to be _var to work with constraints.
676676

677677
if iterator:
678678
H = op_utils.map_reduce(
@@ -797,8 +797,8 @@ def _map(idx, data):
797797
constraints_loc=constraints_loc,
798798
constraints_scale=constraints_scale,
799799
params=model_vars.params,
800-
p_shape_a=model_vars.a.shape[0],
801-
p_shape_b=model_vars.b.shape[0],
800+
p_shape_a=model_vars.a_var.shape[0], # This has to be _var to work with constraints.
801+
p_shape_b=model_vars.b_var.shape[0], # This has to be _var to work with constraints.
802802
dtype=dtype,
803803
size_factors=size_factors
804804
)

batchglm/train/tf/nb_glm/jacobians.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ def _red(prev, cur):
286286
return tf.add(prev, cur)
287287

288288
params = model_vars.params
289-
p_shape_a = model_vars.a.shape[0]
290-
p_shape_b = model_vars.b.shape[0]
289+
p_shape_a = model_vars.a_var.shape[0] # This has to be _var to work with constraints.
290+
p_shape_b = model_vars.b_var.shape[0] # This has to be _var to work with constraints.
291291

292292
if iterator:
293293
J = op_utils.map_reduce(
@@ -374,8 +374,8 @@ def _red(prev, cur):
374374
return tf.add(prev, cur)
375375

376376
params = model_vars.params
377-
p_shape_a = model_vars.a.shape[0]
378-
p_shape_b = model_vars.b.shape[0]
377+
p_shape_a = model_vars.a_var.shape[0] # This has to be _var to work with constraints.
378+
p_shape_b = model_vars.b_var.shape[0] # This has to be _var to work with constraints.
379379

380380
if iterator == True and batch_model is None:
381381
J = op_utils.map_reduce(

0 commit comments

Comments
 (0)