Commit 154e871

Merge branch 'rsa_dev' into dev

2 parents: 66f2bf9 + 1e3032d

File tree: 9 files changed, +132 -102 lines changed


batchglm/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -2,3 +2,5 @@
 
 __version__ = get_versions()['version']
 del get_versions
+
+from .log_cfg import logger, unconfigure_logging, enable_logging

batchglm/api/__init__.py
Lines changed: 3 additions & 0 deletions

@@ -1,3 +1,6 @@
+from .. import __version__
+from ..log_cfg import logger, unconfigure_logging, enable_logging
+
 from . import models
 from . import data
 from . import utils
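
After this change the logging helpers are reachable from both the package root and batchglm.api; a quick hypothetical check:

    import batchglm
    import batchglm.api

    # Both names resolve to the same objects defined in batchglm.log_cfg:
    assert batchglm.logger is batchglm.api.logger
    assert batchglm.enable_logging is batchglm.api.enable_logging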

batchglm/benchmark/nb_glm/base.py
Lines changed: 53 additions & 0 deletions

@@ -1,13 +1,18 @@
+from typing import Tuple
+
 import os
+import logging
 
 from collections import OrderedDict
 import itertools
 
+import xarray as xr
 import pandas as pd
 import yaml
 
 from batchglm.api.models.nb_glm import Simulator, Estimator
 
+logger = logging.getLogger(__name__)
 
 def init_benchmark(
         root_dir: str,
@@ -166,3 +171,51 @@ def run_benchmark(root_dir: str, sample: str, config_file="config.yml"):
     os.remove(os.path.join(working_dir, "lock"))
     touch(os.path.join(working_dir, "ready"))
     print("\t[OK]")
+
+
+def load_benchmark_dataset(root_dir: str, config_file="config.yml") -> Tuple[Simulator, xr.Dataset]:
+    config_file = os.path.join(root_dir, config_file)
+    with open(config_file, mode="r") as f:
+        config = yaml.load(f)
+
+    sim_data_file = os.path.join(root_dir, config["sim_data"])
+    sim = Simulator()
+    sim.load(sim_data_file)
+
+    benchmark_samples = config["benchmark_samples"]
+    benchmark_data = []
+    for smpl, cfg in benchmark_samples.items():
+        wd = cfg["working_dir"]
+        logger.info("opening working dir: %s", wd)
+        ds_path = os.path.join(root_dir, wd, "cache.zarr")
+        try:  # try to open the zarr cache
+            data = xr.open_zarr(ds_path)
+            logger.info("using zarr cache: %s", os.path.join(wd, "cache.zarr"))
+        except:  # fall back to the netcdf4 files
+            logger.info("loading step-wise netcdf4 files...")
+            ncdf_data = xr.open_mfdataset(
+                os.path.join(root_dir, cfg["working_dir"], "estimation-*.h5"),
+                engine="netcdf4",
+                concat_dim="step",
+                autoclose=True,
+                parallel=True,
+            )
+            ncdf_data = ncdf_data.sortby("global_step")
+            ncdf_data.coords["benchmark"] = smpl
+            logger.info("loading step-wise netcdf4 files ready")
+
+            try:  # try to save the data in the zarr cache
+                zarr_data = ncdf_data.to_zarr(ds_path)
+                logger.info("stored data in zarr cache")
+
+                # close the netcdf4 data sets
+                ncdf_data.close()
+                del ncdf_data
+                data = zarr_data
+            except:  # use netcdf4 since zarr does not seem to work
+                data = ncdf_data
+
+        benchmark_data.append(data)
+    benchmark_data = xr.auto_combine(benchmark_data, concat_dim="benchmark", coords="all")
+
+    return sim, benchmark_data
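
The heart of the new load_benchmark_dataset is a cache-or-rebuild pattern: open cache.zarr if possible, otherwise assemble the dataset from the per-step estimation-*.h5 files and try to write the cache for the next run. A minimal sketch of that pattern in isolation (open_cached and the paths are hypothetical; the open_mfdataset keywords follow the diff and may need adapting for newer xarray releases):

    import os
    import xarray as xr

    def open_cached(work_dir: str) -> xr.Dataset:
        """Open cache.zarr under work_dir, rebuilding it from estimation-*.h5 on a miss."""
        cache = os.path.join(work_dir, "cache.zarr")
        try:
            return xr.open_zarr(cache)  # fast path: the cache already exists
        except Exception:  # cache missing or unreadable: rebuild from netcdf4
            ds = xr.open_mfdataset(
                os.path.join(work_dir, "estimation-*.h5"),
                engine="netcdf4",
                concat_dim="step",
            ).sortby("global_step")
            try:
                ds.to_zarr(cache)  # persist the cache for the next run
                ds.close()
                return xr.open_zarr(cache)  # serve from the fresh cache
            except Exception:
                return ds  # zarr not usable here: stay on the netcdf4 files

Unlike the diff, the sketch re-opens the store after writing it, since Dataset.to_zarr returns a store handle rather than a Dataset.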

batchglm/benchmark/nb_glm/convergence.py
Lines changed: 15 additions & 38 deletions

@@ -2,57 +2,34 @@
 
 import os
 import shutil
+import logging
 
 import scipy.stats
 import numpy as np
 import xarray as xr
 import yaml
 
-from .base import init_benchmark, get_benchmark_samples, run_benchmark, Simulator
+from .base import init_benchmark, get_benchmark_samples, run_benchmark, load_benchmark_dataset
+from .base import Simulator
 
 import batchglm.utils.stats as stat_utils
 
-
-def load_benchmark_dataset(root_dir: str, config_file="config.yml") -> Tuple[Simulator, xr.Dataset, dict]:
-    config_file = os.path.join(root_dir, config_file)
-    with open(config_file, mode="r") as f:
-        config = yaml.load(f)
-
-    sim_data_file = os.path.join(root_dir, config["sim_data"])
-    sim = Simulator()
-    sim.load(sim_data_file)
-
-    benchmark_samples = config["benchmark_samples"]
-    benchmark_data = []
-    for smpl, cfg in benchmark_samples.items():
-        data = xr.open_mfdataset(
-            os.path.join(root_dir, cfg["working_dir"], "estimation-*.h5"),
-            engine="netcdf4",
-            concat_dim="step",
-            autoclose=True,
-            parallel=True,
-        )
-        data = data.sortby("global_step")
-        data.coords["benchmark"] = smpl
-        benchmark_data.append(data)
-    benchmark_data = xr.auto_combine(benchmark_data, concat_dim="benchmark", coords="all")
-
-    return sim, benchmark_data, benchmark_samples
+logger = logging.getLogger(__name__)
 
 
 def plot_benchmark(root_dir: str, config_file="config.yml"):
-    print("loading config...", end="", flush=True)
+    logger.info("loading config...", end="", flush=True)
     config_file = os.path.join(root_dir, config_file)
     with open(config_file, mode="r") as f:
         config = yaml.load(f)
-    print("\t[OK]")
+    logger.info("\t[OK]")
 
     plot_dir = os.path.join(root_dir, config["plot_dir"])
 
-    print("loading data...", end="", flush=True)
-    sim, benchmark_data, benchmark_sample_config = load_benchmark_dataset(root_dir)
+    logger.info("loading data...", end="", flush=True)
+    sim, benchmark_data = load_benchmark_dataset(root_dir)
     benchmark_data.coords["time_elapsed"] = benchmark_data.time_elapsed.cumsum("step")
-    print("\t[OK]")
+    logger.info("\t[OK]")
 
     import plotnine as pn
     import matplotlib.pyplot as plt
@@ -83,7 +60,7 @@ def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
             plot = plot + pn.scale_y_log10()
         plot.save(os.path.join(plot_dir, name_prefix + ".step.svg"), format="svg")
 
-    print("plotting...")
+    logger.info("plotting...")
     val: xr.DataArray = stat_utils.rmsd(
         np.exp(xr.DataArray(sim.params["a"][0], dims=("features",))),
         np.exp(benchmark_data.a.isel(design_loc_params=0)), axis=[0])
@@ -98,7 +75,7 @@ def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
     plot_stat(val, "loss", "loss")
 
     def plot_pval(window_size):
-        print("plotting p-value with window size: %d" % window_size)
+        logger.info("plotting p-value with window size: %d" % window_size)
 
         roll1 = benchmark_data.loss.rolling(step=window_size)
         roll2 = benchmark_data.loss.roll(step=window_size).rolling(step=window_size)
@@ -133,7 +110,7 @@ def plot_pval(window_size):
         plt.close()
 
     def plot_loss_rolling_mean(window_size):
-        print("plotting rolling mean of batch loss with window size: %d" % window_size)
+        logger.info("plotting rolling mean of batch loss with window size: %d" % window_size)
 
         benchmark_data.loss.rolling(step=window_size).mean().plot.line(hue="benchmark")
         plt.savefig(os.path.join(plot_dir, "batch_loss_rolling_mean.%dsteps.svg" % window_size), format="svg")
@@ -149,7 +126,7 @@ def plot_loss_rolling_mean(window_size):
     plt.savefig(os.path.join(plot_dir, "mean_full_gradient.svg"), format="svg")
     plt.close()
 
-    print("ready")
+    logger.info("ready")
 
 
 def clean(root_dir: str):
@@ -161,7 +138,7 @@ def clean(root_dir: str):
             elif os.path.isdir(file_path):
                 shutil.rmtree(file_path)
         except Exception as e:
-            print(e)
+            logger.info(e)
 
 
 def main():
@@ -235,7 +212,7 @@ def main():
     elif action == "print_samples":
         benchmark_samples = get_benchmark_samples(root_dir)
        for smpl in benchmark_samples:
-            print(smpl)
+            logger.info(smpl)
     elif action == "plot":
         plot_benchmark(root_dir)
     elif action == "clean":
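
One caveat about the print-to-logger conversion above: Logger.info accepts only logging-specific keywords (exc_info, stack_info, extra), not print's end and flush, so the converted calls that kept those keywords, such as logger.info("loading config...", end="", flush=True), raise TypeError when they run. A sketch of a keyword-free conversion of the same progress message:

    import logging

    logger = logging.getLogger(__name__)

    # Each log record is emitted as a complete line, so the old
    # print(..., end="") / print("\t[OK]") pair becomes two records:
    logger.info("loading config...")
    # ... load the config here ...
    logger.info("loading config... [OK]")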

batchglm/benchmark/nb_glm/performance.py
Lines changed: 17 additions & 47 deletions

@@ -2,16 +2,20 @@
 
 import os
 import shutil
+import logging
 
 import numpy as np
 import xarray as xr
 import yaml
 
-# import batchglm.pkg_constants
+from .base import load_benchmark_dataset, get_benchmark_samples
+
 from batchglm.api.models.nb_glm import Simulator, Estimator
 
 import batchglm.utils.stats as stat_utils
 
+logger = logging.getLogger(__name__)
+
 
 def init_benchmark(
         root_dir: str,
@@ -99,13 +103,6 @@ def prepare_benchmark_sample(
     return sample_config
 
 
-def get_benchmark_samples(root_dir: str, config_file="config.yml"):
-    config_file = os.path.join(root_dir, config_file)
-    with open(config_file, mode="r") as f:
-        config = yaml.load(f)
-    return list(config["benchmark_samples"].keys())
-
-
 def run_benchmark(root_dir: str, sample: str, config_file="config.yml"):
     config_file = os.path.join(root_dir, config_file)
     with open(config_file, mode="r") as f:
@@ -122,58 +119,31 @@ def run_benchmark(root_dir: str, sample: str, config_file="config.yml"):
     init_args = sample_config["init_args"]
     init_args["working_dir"] = working_dir
 
-    print("loading data...", end="", flush=True)
+    logger.info("loading data...", end="", flush=True)
     sim = Simulator()
     sim.load(sim_data_file)
-    print("\t[OK]")
+    logger.info("\t[OK]")
 
-    print("starting estimation of benchmark sample '%s'..." % sample)
+    logger.info("starting estimation of benchmark sample '%s'..." % sample)
     estimator = Estimator(sim.input_data, batch_size=batch_size)
     estimator.initialize(**init_args)
     estimator.train(learning_rate=learning_rate)
-    print("estimation of benchmark sample '%s' ready" % sample)
-
-
-def load_benchmark_dataset(root_dir: str, config_file="config.yml") -> Tuple[Simulator, xr.Dataset]:
-    config_file = os.path.join(root_dir, config_file)
-    with open(config_file, mode="r") as f:
-        config = yaml.load(f)
-
-    sim_data_file = os.path.join(root_dir, config["sim_data"])
-    sim = Simulator()
-    sim.load(sim_data_file)
-
-    benchmark_samples = config["benchmark_samples"]
-    benchmark_data = []
-    for smpl, cfg in benchmark_samples.items():
-        data = xr.open_mfdataset(
-            os.path.join(root_dir, cfg["working_dir"], "estimation-*.h5"),
-            engine="netcdf4",
-            concat_dim="step",
-            autoclose=True,
-            parallel=True,
-        )
-        data = data.sortby("global_step")
-        data.coords["benchmark"] = smpl
-        benchmark_data.append(data)
-    benchmark_data = xr.auto_combine(benchmark_data, concat_dim="benchmark", coords="all")
-
-    return sim, benchmark_data
+    logger.info("estimation of benchmark sample '%s' ready" % sample)
 
 
 def plot_benchmark(root_dir: str, config_file="config.yml"):
-    print("loading config...", end="", flush=True)
+    logger.info("loading config...", end="", flush=True)
     config_file = os.path.join(root_dir, config_file)
     with open(config_file, mode="r") as f:
         config = yaml.load(f)
-    print("\t[OK]")
+    logger.info("\t[OK]")
 
     plot_dir = os.path.join(root_dir, config["plot_dir"])
 
-    print("loading data...", end="", flush=True)
+    logger.info("loading data...", end="", flush=True)
     sim, benchmark_data = load_benchmark_dataset(root_dir)
     benchmark_data.coords["time_elapsed"] = benchmark_data.time_elapsed.cumsum("step")
-    print("\t[OK]")
+    logger.info("\t[OK]")
 
     import plotnine as pn
     import matplotlib.pyplot as plt
@@ -204,7 +174,7 @@ def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
            plot = plot + pn.scale_y_log10()
        plot.save(os.path.join(plot_dir, name_prefix + ".step.svg"), format="svg")
 
-    print("plotting...")
+    logger.info("plotting...")
     val: xr.DataArray = stat_utils.rmsd(
         np.exp(xr.DataArray(sim.params["a"][0], dims=("features",))),
         np.exp(benchmark_data.a.isel(design_loc_params=0)), axis=[0])
@@ -218,7 +188,7 @@ def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
     val: xr.DataArray = benchmark_data.loss
     plot_stat(val, "loss", "loss")
 
-    print("ready")
+    logger.info("ready")
 
 
 def clean(root_dir: str):
@@ -230,7 +200,7 @@ def clean(root_dir: str):
             elif os.path.isdir(file_path):
                 shutil.rmtree(file_path)
         except Exception as e:
-            print(e)
+            logger.info(e)
 
 
 def main():
@@ -293,7 +263,7 @@ def main():
     elif action == "print_samples":
         benchmark_samples = get_benchmark_samples(root_dir)
         for smpl in benchmark_samples:
-            print(smpl)
+            logger.info(smpl)
     elif action == "plot":
         plot_benchmark(root_dir)
     elif action == "clean":
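
With the duplicated helpers gone, both benchmark scripts resolve load_benchmark_dataset and get_benchmark_samples to the single copies in base.py. A hypothetical session against the shared entry points (the root_dir value is made up):

    from batchglm.benchmark.nb_glm.base import (
        get_benchmark_samples,
        load_benchmark_dataset,
    )

    root_dir = "benchmarks/nb_glm"  # hypothetical benchmark directory

    for sample in get_benchmark_samples(root_dir):
        print(sample)  # list the configured benchmark samples

    # The consolidated loader returns a pair, without the former config dict:
    sim, benchmark_data = load_benchmark_dataset(root_dir)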

batchglm/log_cfg.py
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+import sys
+
+import logging
+
+logger = logging.getLogger('.'.join(__name__.split('.')[:-1]))
+
+_is_interactive = bool(getattr(sys, 'ps1', sys.flags.interactive))
+_hander = None
+
+
+def unconfigure_logging():
+    if _hander is not None:
+        logger.removeHandler(_hander)
+
+    logger.setLevel(logging.NOTSET)
+
+
+def enable_logging(verbosity=logging.ERROR, stream=sys.stderr, format=logging.BASIC_FORMAT):
+    unconfigure_logging()
+
+    logger.setLevel(verbosity)
+    _handler = logging.StreamHandler(stream)
+    _handler.setFormatter(logging.Formatter(format, None))
+    logger.addHandler(_handler)
+
+
+# If we are in an interactive environment (like Jupyter), set loglevel to INFO and pipe the output to stdout.
+if _is_interactive:
+    enable_logging(logging.INFO, sys.stdout)
+else:
+    enable_logging(logging.WARNING, sys.stderr)
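
log_cfg configures itself on import: INFO to stdout in interactive sessions (e.g. Jupyter), WARNING to stderr otherwise, and the two helpers let callers override that. A minimal usage sketch, assuming only what the diff shows:

    import logging
    import batchglm

    # Replace the import-time default: show DEBUG and above on stderr.
    batchglm.enable_logging(verbosity=logging.DEBUG)
    batchglm.logger.debug("now visible")

    # Drop the package logger level back to NOTSET.
    batchglm.unconfigure_logging()

Note that, as committed, enable_logging stores its handler in a local _handler while unconfigure_logging checks the module-level _hander, so repeated enable_logging calls stack handlers rather than replacing them.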

batchglm/train/tf/nb_glm/base.py
Lines changed: 0 additions & 5 deletions

@@ -7,11 +7,6 @@
 
 import numpy as np
 
-try:
-    import anndata
-except ImportError:
-    anndata = None
-
 from .external import AbstractEstimator
 from .external import nb_utils
 from .external import pkg_constants
