Skip to content

Commit eae5c98

Browse files
allowed overriding train args set in trianing strategy via train_kwargs
1 parent 1100e1a commit eae5c98

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

batchglm/models/base/estimator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
import pprint
8+
import sys
89

910
try:
1011
import anndata
@@ -112,6 +113,14 @@ def train_sequence(
112113
logger.debug("training strategy:\n%s", pprint.pformat(training_strategy))
113114
for idx, d in enumerate(training_strategy):
114115
logger.debug("Beginning with training sequence #%d", idx + 1)
116+
# Override duplicate arguments with user choice:
117+
if np.any([x in list(d.keys()) for x in list(kwargs.keys())]):
118+
d = dict([(x, y) for x, y in d.items() if x not in list(kwargs.keys())])
119+
for x in [xx for xx in list(d.keys()) if xx in list(kwargs.keys())]:
120+
sys.stdout.write(
121+
"overrding %s from training strategy with value %s with new value %s\n" %
122+
(x, str(d[x]), str(kwargs[x]))
123+
)
115124
self.train(**d, **kwargs)
116125
logger.debug("Training sequence #%d complete", idx + 1)
117126

0 commit comments

Comments
 (0)