Skip to content

Commit cd8c917

Browse files
committed
added eager executing tf2 training with nb noise
1 parent 238bb5d commit cd8c917

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+4108
-0
lines changed

batchglm/train/tf2/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from . import glm_nb as nb
2+
from . import glm_norm as norm
3+
from . import glm_beta as beta
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .estimator import TFEstimator
2+
from .model import ProcessModelBase, ModelBase, LossBase
3+
from .optim import OptimizerBase
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from .external import pkg_constants, TrainingStrategies
2+
from .model import ModelBase, LossBase
3+
4+
import numpy as np
5+
import tensorflow as tf
6+
7+
8+
class TFEstimator:
9+
model: ModelBase
10+
loss: LossBase
11+
12+
def __init__(self, input_data, dtype):
13+
14+
self._input_data = input_data
15+
self.dtype = dtype
16+
17+
def _train(
18+
self,
19+
batched_model: bool,
20+
batch_size: int,
21+
optimizer_object: tf.keras.optimizers.Optimizer,
22+
optimizer_enum: TrainingStrategies,
23+
convergence_criteria: str,
24+
stopping_criteria: int,
25+
autograd: bool,
26+
featurewise: bool,
27+
benchmark: bool
28+
):
29+
pass
30+
31+
def fetch_fn(self, idx):
32+
pass
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#from batchglm.models.base import _Estimator_Base
2+
#from batchglm.xarray_sparse import SparseXArrayDataArray, SparseXArrayDataSet
3+
from batchglm.train.tf2.base_glm.training_strategies import TrainingStrategies
4+
#import batchglm.utils.stats as stat_utils
5+
from batchglm import pkg_constants

batchglm/train/tf2/base/model.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import abc
2+
import logging
3+
import tensorflow as tf
4+
import numpy as np
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
class ModelBase(tf.keras.Model, metaclass=abc.ABCMeta):
10+
11+
def __init__(self):
12+
super(ModelBase, self).__init__()
13+
14+
@abc.abstractmethod
15+
def call(self, inputs, training=False, mask=None):
16+
pass
17+
18+
19+
class LossBase(tf.keras.losses.Loss, metaclass=abc.ABCMeta):
20+
21+
def __init__(self):
22+
super(LossBase, self).__init__()
23+
24+
@abc.abstractmethod
25+
def call(self, y_true, y_pred):
26+
pass
27+
28+
29+
class ProcessModelBase:
30+
31+
@abc.abstractmethod
32+
def param_bounds(self, dtype):
33+
pass
34+
35+
def tf_clip_param(
36+
self,
37+
param,
38+
name
39+
):
40+
bounds_min, bounds_max = self.param_bounds(param.dtype)
41+
return tf.clip_by_value(
42+
param,
43+
bounds_min[name],
44+
bounds_max[name]
45+
)
46+
47+
def np_clip_param(
48+
self,
49+
param,
50+
name
51+
):
52+
bounds_min, bounds_max = self.param_bounds(param.dtype)
53+
return np.clip(
54+
param,
55+
bounds_min[name],
56+
bounds_max[name]
57+
)

batchglm/train/tf2/base/optim.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import abc
2+
import logging
3+
import tensorflow as tf
4+
5+
logger = logging.getLogger("batchglm")
6+
7+
8+
class OptimizerBase(tf.keras.optimizers.Optimizer, metaclass=abc.ABCMeta):
9+
10+
def __init__(self, name):
11+
super(OptimizerBase, self).__init__(name=name)
12+
13+
@abc.abstractmethod
14+
def _resource_apply_dense(self, grad, handle):
15+
pass
16+
17+
@abc.abstractmethod
18+
def _resource_apply_sparse(self, grad, handle, apply_state):
19+
pass
20+
21+
@abc.abstractmethod
22+
def _create_slots(self):
23+
pass
24+
25+
"""
26+
@property
27+
@abc.abstractmethod
28+
def vars(self):
29+
pass
30+
31+
@property
32+
@abc.abstractmethod
33+
def gradients(self):
34+
return None
35+
36+
@property
37+
@abc.abstractmethod
38+
def hessians(self):
39+
pass
40+
41+
@property
42+
@abc.abstractmethod
43+
def fims(self):
44+
pass
45+
46+
@abc.abstractmethod
47+
def step(self, learning_rate):
48+
pass
49+
"""
50+
@abc.abstractmethod
51+
def get_config(self):
52+
pass
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Classes with GLM specific code.
2+
All noise models that are in the GLM category inherit all of these classes.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .processModel import ProcessModelGLM
2+
from .model import GLM, LossGLM
3+
4+
from .estimator import Estimator
5+
from .vars import ModelVarsGLM
6+
from .layers import LinearLocGLM, LinearScaleGLM, LinkerLocGLM, LinkerScaleGLM
7+
from .layers import LikelihoodGLM, UnpackParamsGLM
8+
from .layers_gradients import JacobianGLM, HessianGLM, FIMGLM
9+
from .optim import NR, IRLS
10+
from .training_strategies import TrainingStrategies

0 commit comments

Comments
 (0)