Skip to content

Commit 5d42742

Browse files
Fixed some functions for the interface to diffxpy
1 parent f2aff55 commit 5d42742

File tree

4 files changed

+41
-19
lines changed

4 files changed

+41
-19
lines changed

batchglm/data.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -566,7 +566,7 @@ def parse_constraints(
566566
Parse constraint matrix into xarray.
567567
568568
:param dmat: Design matrix.
569-
:param a constraint matrix
569+
:param constraints: a constraint matrix
570570
:return: constraint matrix in xarray format
571571
"""
572572
constraints_ar = xr.DataArray(

batchglm/models/base_glm/utils.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ def parse_constraints(
7777
dmat: xr.Dataset,
7878
dims,
7979
constraints: np.ndarray = None,
80-
constraint_par_names: list = None,
80+
constraint_par_names: list = None
8181
) -> xr.DataArray:
8282
r"""
8383
Parser for constraint matrices.
@@ -90,7 +90,11 @@ def parse_constraints(
9090
"""
9191
if constraints is None:
9292
constraints = np.identity(n=dmat.shape[1])
93+
# Use given parameter names if constraint matrix is identity.
94+
par_names = dmat.coords[dims[0]]
9395
else:
96+
# Cannot use given parameter names if constraint matrix is not identity: Make up new ones.
97+
par_names = ["var_"+str(x) for x in range(constraints.shape[1])]
9498
assert constraints.shape[0] == dmat.shape[1], "constraint dimension mismatch"
9599

96100
constraints_mat = xr.DataArray(
@@ -99,7 +103,7 @@ def parse_constraints(
99103
)
100104
constraints_mat.coords[dims[0]] = dmat.coords[dims[0]]
101105
if constraint_par_names is None:
102-
constraint_par_names = ["var_"+str(x) for x in range(constraints_mat.shape[1])]
106+
constraint_par_names = par_names
103107

104108
constraints_mat.coords[dims[1]] = constraint_par_names
105109

batchglm/train/tf/glm_nb/estimator.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,7 @@ def init_par(
187187
)
188188
else:
189189
init_a_xr = data_utils.xarray_from_data(init_a, dims=("loc_params", "features"))
190-
init_a_xr.coords["loc_params"] = input_data.constraints_loc.coords["loc_params"]
190+
init_a_xr.coords["loc_params"] = input_data.constraints_loc.coords["loc_params"].values
191191
init_mu = input_data.design_loc.dot(input_data.constraints_loc.dot(init_a_xr))
192192

193193
if size_factors_init is not None:
@@ -236,28 +236,28 @@ def init_par(
236236
else:
237237
# Locations model:
238238
if isinstance(init_a, str) and (init_a.lower() == "auto" or init_a.lower() == "init_model"):
239-
my_loc_names = set(input_data.design_loc_names.values)
240-
my_loc_names = my_loc_names.intersection(init_model.input_data.design_loc_names.values)
239+
my_loc_names = set(input_data.loc_names.values)
240+
my_loc_names = my_loc_names.intersection(set(init_model.input_data.loc_names.values))
241241

242242
init_loc = np.zeros([input_data.num_loc_params, input_data.num_features])
243243
for parm in my_loc_names:
244-
init_idx = np.where(init_model.input_data.design_loc_names == parm)
245-
my_idx = np.where(input_data.design_loc_names == parm)
246-
init_loc[my_idx] = init_model.par_link_loc[init_idx]
244+
init_idx = np.where(init_model.input_data.loc_names == parm)[0]
245+
my_idx = np.where(input_data.loc_names == parm)[0]
246+
init_loc[my_idx] = init_model.a_var[init_idx]
247247

248248
init_a = init_loc
249249
logger.debug("Using initialization based on input model for mean")
250250

251251
# Scale model:
252252
if isinstance(init_b, str) and (init_b.lower() == "auto" or init_b.lower() == "init_model"):
253-
my_scale_names = set(input_data.design_scale_names.values)
254-
my_scale_names = my_scale_names.intersection(init_model.input_data.design_scale_names.values)
253+
my_scale_names = set(input_data.scale_names.values)
254+
my_scale_names = my_scale_names.intersection(init_model.input_data.scale_names.values)
255255

256256
init_scale = np.zeros([input_data.num_scale_params, input_data.num_features])
257257
for parm in my_scale_names:
258-
init_idx = np.where(init_model.input_data.design_scale_names == parm)
259-
my_idx = np.where(input_data.design_scale_names == parm)
260-
init_scale[my_idx] = init_model.par_link_scale[init_idx]
258+
init_idx = np.where(init_model.input_data.scale_names == parm)[0]
259+
my_idx = np.where(input_data.scale_names == parm)[0]
260+
init_scale[my_idx] = init_model.b_var[init_idx]
261261

262262
init_b = init_scale
263263
logger.debug("Using initialization based on input model for dispersion")

batchglm/xarray_sparse/base.py

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,9 @@ def new_from_x(self, x):
6262
def dtype(self):
6363
return self.X.dtype
6464

65+
def astype(self, dtype):
66+
return self.new_from_x(self.X.astype(dtype))
67+
6568
@property
6669
def shape(self):
6770
return self.X.shape
@@ -128,12 +131,19 @@ def mean(self, dim: str = None, axis: int = None):
128131
else:
129132
return np.asarray(self.X.mean()).flatten()
130133

131-
def var(self, dim: str):
132-
assert dim in self.dims, "dim not recognized"
133-
axis = self.dims.index(dim)
134+
def var(self, dim: str = None, axis: int = None):
135+
assert not (dim is not None and axis is not None), "only supply dim or axis"
136+
if dim is not None:
137+
assert dim in self.dims, "dim not recognized"
138+
axis = self.dims.index(dim)
139+
elif axis is not None:
140+
assert axis < len(self.X.shape), "axis index out of range"
141+
else:
142+
assert False, "supply either dim or axis"
143+
134144
Xsq = self.square(copy=True)
135-
expect_x_sq = np.square(self.mean(dim=dim))
136-
expect_xsq = np.mean(Xsq, axis=axis)
145+
expect_x_sq = np.square(self.mean(axis=axis))
146+
expect_xsq = Xsq.mean(axis=axis)
137147
return np.asarray(expect_xsq - expect_x_sq).flatten()
138148

139149
def std(self, dim: str):
@@ -177,6 +187,14 @@ def group_vars(self, dim):
177187
def __copy__(self):
178188
return type(self)(self.X)
179189

190+
def __getitem__(self, key):
191+
if isinstance(key, np.ndarray) or isinstance(key, slice): # This is an observation wise slice!
192+
return self.new_from_x(x=self.X[key])
193+
elif key in self.coords:
194+
return self.coords[key]
195+
else:
196+
return self.__getattribute__(key)
197+
180198

181199
class SparseXArrayDataSet:
182200
"""

0 commit comments

Comments
 (0)