@@ -502,8 +502,8 @@ def _red(prev, cur):
502502 return tf .add (prev , cur )
503503
504504 params = model_vars .params
505- p_shape_a = model_vars .a .shape [0 ]
506- p_shape_b = model_vars .b .shape [0 ]
505+ p_shape_a = model_vars .a_var .shape [0 ] # This has to be _var to work with constraints.
506+ p_shape_b = model_vars .b_var .shape [0 ] # This has to be _var to work with constraints.
507507
508508 if iterator :
509509 H = op_utils .map_reduce (
@@ -671,8 +671,8 @@ def _red(prev, cur):
671671 return [tf .add (p , c ) for p , c in zip (prev , cur )]
672672
673673 params = model_vars .params
674- p_shape_a = model_vars .a .shape [0 ]
675- p_shape_b = model_vars .b .shape [0 ]
674+ p_shape_a = model_vars .a_var .shape [0 ] # This has to be _var to work with constraints.
675+ p_shape_b = model_vars .b_var .shape [0 ] # This has to be _var to work with constraints.
676676
677677 if iterator :
678678 H = op_utils .map_reduce (
@@ -797,8 +797,8 @@ def _map(idx, data):
797797 constraints_loc = constraints_loc ,
798798 constraints_scale = constraints_scale ,
799799 params = model_vars .params ,
800- p_shape_a = model_vars .a .shape [0 ],
801- p_shape_b = model_vars .b .shape [0 ],
800+ p_shape_a = model_vars .a_var .shape [0 ], # This has to be _var to work with constraints.
801+ p_shape_b = model_vars .b_var .shape [0 ], # This has to be _var to work with constraints.
802802 dtype = dtype ,
803803 size_factors = size_factors
804804 )
0 commit comments