Skip to content

Commit 886a241

Browse files
the group wise means were incorrect if models mismatched, recompute always now.
1 parent d63cdfc commit 886a241

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -865,8 +865,7 @@ def __init__(
865865
Xdiff = X - np.exp(input_data.design_loc.dot(init_a_xr))
866866
variance = np.square(Xdiff).groupby("group").mean(dim="observations")
867867

868-
if groupwise_means is None:
869-
groupwise_means = X.groupby("group").mean(dim="observations")
868+
groupwise_means = X.groupby("group").mean(dim="observations")
870869
denominator = np.fmax(variance - groupwise_means, 0)
871870
denominator = np.nextafter(0, 1, out=denominator.values, where=denominator == 0,
872871
dtype=denominator.dtype)

0 commit comments

Comments
 (0)