Skip to content

Commit a573fcd

Browse files
removed randomness in model based initialisation, set all to zero
1 parent d99ac37 commit a573fcd

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

batchglm/train/tf/glm_nb/estimator.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,7 @@ def init_par(
190190
my_loc_names = set(self.input_data.design_loc_names.values)
191191
my_loc_names = my_loc_names.intersection(init_model.input_data.design_loc_names.values)
192192

193-
init_loc = np.random.uniform(
194-
low=np.nextafter(0, 1, dtype=self.input_data.X.dtype),
195-
high=np.sqrt(np.nextafter(0, 1, dtype=self.input_data.X.dtype)),
196-
size=(self.input_data.num_design_loc_params, self.input_data.num_features)
197-
)
193+
init_loc = np.zeros([self.input_data.num_loc_params, self.input_data.num_features])
198194
for parm in my_loc_names:
199195
init_idx = np.where(init_model.input_data.design_loc_names == parm)
200196
my_idx = np.where(input_data.design_loc_names == parm)
@@ -208,11 +204,7 @@ def init_par(
208204
my_scale_names = set(input_data.design_scale_names.values)
209205
my_scale_names = my_scale_names.intersection(init_model.input_data.design_scale_names.values)
210206

211-
init_scale = np.random.uniform(
212-
low=np.nextafter(0, 1, dtype=self.input_data.X.dtype),
213-
high=np.sqrt(np.nextafter(0, 1, dtype=self.input_data.X.dtype)),
214-
size=(self.input_data.num_design_scale_params, self.input_data.num_features)
215-
)
207+
init_scale = np.zeros([self.input_data.num_scale_params, self.input_data.num_features])
216208
for parm in my_scale_names:
217209
init_idx = np.where(init_model.input_data.design_scale_names == parm)
218210
my_idx = np.where(input_data.design_scale_names == parm)

0 commit comments

Comments
 (0)