Skip to content

Commit 68f227c

Browse files
adapted estimator graph to handle no training
ie direct passing of closed form estimators
1 parent c9f8051 commit 68f227c

File tree

6 files changed

+306
-282
lines changed

6 files changed

+306
-282
lines changed

batchglm/models/base/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def loss(self):
8080

8181
@property
8282
@abc.abstractmethod
83-
def gradient(self):
83+
def gradients(self):
8484
pass
8585

8686

batchglm/models/base_glm/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ESTIMATOR_PARAMS.update({
1313
"loss": (),
1414
"log_likelihood": ("features",),
15-
"gradient": ("features",),
15+
"gradients": ("features",),
1616
"hessians": ("features", "delta_var0", "delta_var1"),
1717
"fisher_inv": ("features", "delta_var0", "delta_var1"),
1818
})
@@ -32,8 +32,8 @@ def log_likelihood(self):
3232
return self.params["log_likelihood"]
3333

3434
@property
35-
def gradient(self):
36-
return self.params["gradient"]
35+
def gradients(self):
36+
return self.params["gradients"]
3737

3838
@property
3939
def hessians(self):

batchglm/models/glm_nb/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, estim: AbstractEstimator):
2323
# causes evaluation of the properties that have not been computed during
2424
# training, such as the hessian.
2525
params = estim.to_xarray(
26-
["a_var", "b_var", "loss", "log_likelihood", "gradient", "hessians", "fisher_inv"],
26+
["a_var", "b_var", "loss", "log_likelihood", "gradients", "hessians", "fisher_inv"],
2727
coords=input_data.data
2828
)
2929

0 commit comments

Comments
 (0)