Skip to content

Commit 6252825

Browse files
Merge pull request #82 from theislab/dev
Dev
2 parents bbe9539 + 9f6fcdb commit 6252825

File tree

9 files changed

+72
-10
lines changed

9 files changed

+72
-10
lines changed

batchglm/api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33

44
from . import models
55
from . import data
6+
from . import typing
67
from . import utils
78
from .. import pkg_constants

batchglm/api/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from batchglm.models.base.estimator import EstimatorBaseTyping
2+
from batchglm.models.base.input import InputDataBaseTyping

batchglm/models/base/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .input import _InputDataBase
2-
from .estimator import _EstimatorBase
1+
from .input import _InputDataBase, InputDataBaseTyping
2+
from .estimator import _EstimatorBase, EstimatorBaseTyping
33
from .model import _ModelBase
44
from .simulator import _SimulatorBase

batchglm/models/base/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,9 @@ def _plot_deviation(
291291
else:
292292
return
293293

294+
295+
class EstimatorBaseTyping(_EstimatorBase):
296+
r"""
297+
Estimator base class used for typing in other packages.
298+
"""
299+

batchglm/models/base/input.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,8 @@ def fetch_x_sparse(self, idx):
9090
data_idx = np.squeeze(data_idx, axis=0)
9191

9292
return data_idx, data_val, data_shape
93+
94+
class InputDataBaseTyping:
95+
"""
96+
Input data base class used for typing in other packages.
97+
"""

batchglm/models/base/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import abc
2-
import numpy as np
32
from typing import Union, Any, Dict, Iterable
43
import logging
54

batchglm/models/base_glm/input.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ class InputDataGLM(_InputDataBase):
1717
"""
1818
Input data for Generalized Linear Models (GLMs).
1919
"""
20+
loc_names: list
21+
design_loc_names: list
22+
scale_names: list
23+
design_scale_names: list
2024

2125
def __init__(
2226
self,
@@ -94,8 +98,8 @@ def __init__(
9498

9599
self.design_loc = design_loc
96100
self.design_scale = design_scale
97-
self.design_loc_names = design_loc_names
98-
self.design_scale_names = design_scale_names
101+
self._design_loc_names = design_loc_names
102+
self._design_scale_names = design_scale_names
99103

100104
constraints_loc, loc_names = parse_constraints(
101105
dmat=design_loc,
@@ -111,11 +115,27 @@ def __init__(
111115
)
112116
self.constraints_loc = constraints_loc
113117
self.constraints_scale = constraints_scale
114-
self.loc_names = loc_names
115-
self.scale_names = scale_names
118+
self._loc_names = loc_names
119+
self._scale_names = scale_names
116120

117121
self.size_factors = size_factors
118122

123+
@property
124+
def design_loc_names(self):
125+
return self._design_loc_names
126+
127+
@property
128+
def design_scale_names(self):
129+
return self._design_scale_names
130+
131+
@property
132+
def loc_names(self):
133+
return self._loc_names
134+
135+
@property
136+
def scale_names(self):
137+
return self._scale_names
138+
119139
@property
120140
def num_design_loc_params(self):
121141
return self.design_loc.shape[1]

batchglm/models/base_glm/model.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
anndata = None
88

99
from .external import _ModelBase
10+
from .input import InputDataGLM
1011

1112

1213
class _ModelGLM(_ModelBase, metaclass=abc.ABCMeta):
@@ -26,7 +27,7 @@ class _ModelGLM(_ModelBase, metaclass=abc.ABCMeta):
2627

2728
def __init__(
2829
self,
29-
input_data
30+
input_data: InputDataGLM
3031
):
3132
_ModelBase.__init__(
3233
self=self,
@@ -63,6 +64,34 @@ def constraints_scale(self) -> np.ndarray:
6364
else:
6465
return self.input_data.constraints_scale
6566

67+
@property
68+
def design_loc_names(self) -> list:
69+
if self.input_data is None:
70+
return None
71+
else:
72+
return self.input_data.design_loc_names
73+
74+
@property
75+
def design_scale_names(self) -> list:
76+
if self.input_data is None:
77+
return None
78+
else:
79+
return self.input_data.design_scale_names
80+
81+
@property
82+
def loc_names(self) -> list:
83+
if self.input_data is None:
84+
return None
85+
else:
86+
return self.input_data.loc_names
87+
88+
@property
89+
def scale_names(self) -> list:
90+
if self.input_data is None:
91+
return None
92+
else:
93+
return self.input_data.scale_names
94+
6695
@abc.abstractmethod
6796
def eta_loc(self) -> np.ndarray:
6897
pass

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
long_description_content_type="text/markdown",
1818
packages=find_packages(),
1919
install_requires=[
20-
'tensorflow==1.14.0',
20+
'tensorflow>=1.14.0',
2121
'tensorflow-probability>=0.7',
22-
'numpy==1.16.4',
22+
'numpy>=1.16.4',
2323
'scipy>=1.2.1',
2424
'pandas',
2525
'dask',

0 commit comments

Comments
 (0)