Skip to content

Commit 5e8c0d1

Browse files
Merge branch 'dev' of https://github.com/theislab/batchglm into dev
2 parents 86f4e5d + 195447e commit 5e8c0d1

File tree

28 files changed

+1192
-421
lines changed

28 files changed

+1192
-421
lines changed

batchglm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
__version__ = get_versions()['version']
44
del get_versions
55

6-
from .log_cfg import logger, unconfigure_logging, enable_logging
6+
from .log_cfg import logger, unconfigure_logging, setup_logging

batchglm/api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .. import __version__
2-
from ..log_cfg import logger, unconfigure_logging, enable_logging
2+
from ..log_cfg import logger, unconfigure_logging, setup_logging
33

44
from . import models
55
from . import data

batchglm/api/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from . import stats
22
from . import random
3+
from . import numeric
4+
from . import linalg

batchglm/api/utils/linalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from batchglm.utils.linalg import stacked_lstsq, groupwise_solve_lm

batchglm/api/utils/numeric.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from batchglm.utils.numeric import combine_matrices, softmax, weighted_mean, weighted_variance

batchglm/benchmark/nb_glm/base.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import logging
5+
import traceback
56

67
from collections import OrderedDict
78
import itertools
@@ -14,6 +15,7 @@
1415

1516
logger = logging.getLogger(__name__)
1617

18+
1719
def init_benchmark(
1820
root_dir: str,
1921
sim: Simulator,
@@ -110,10 +112,16 @@ def prepare_benchmark_sample(
110112
return sample_config
111113

112114

113-
def get_benchmark_samples(root_dir: str, config_file="config.yml"):
115+
def load_config(root_dir, config_file):
114116
config_file = os.path.join(root_dir, config_file)
115117
with open(config_file, mode="r") as f:
116118
config = yaml.load(f)
119+
120+
return config
121+
122+
123+
def get_benchmark_samples(root_dir: str, config_file="config.yml"):
124+
config = load_config(root_dir, config_file)
117125
return list(config["benchmark_samples"].keys())
118126

119127

@@ -186,18 +194,29 @@ def load_benchmark_dataset(root_dir: str, config_file="config.yml") -> Tuple[Sim
186194
benchmark_data = []
187195
for smpl, cfg in benchmark_samples.items():
188196
wd = cfg["working_dir"]
189-
logger.info("opening working dir: %s", wd)
197+
190198
ds_path = os.path.join(root_dir, wd, "cache.zarr")
199+
ds_cache_OK = os.path.join(root_dir, wd, "cache_OK")
200+
201+
logger.info("opening working dir: %s", wd)
191202
try: # try open zarr cache
203+
if not os.path.exists(ds_cache_OK):
204+
raise FileNotFoundError
205+
192206
data = xr.open_zarr(ds_path)
193207
logger.info("using zarr cache: %s", os.path.join(wd, "cache.zarr"))
194-
except: # open netcdf4 files
208+
except BaseException as e: # open netcdf4 files
209+
if isinstance(e, FileNotFoundError):
210+
pass
211+
else:
212+
traceback.print_exc()
213+
195214
logger.info("loading step-wise netcdf4 files...")
196215
ncdf_data = xr.open_mfdataset(
197216
os.path.join(root_dir, cfg["working_dir"], "estimation-*.h5"),
198-
engine="netcdf4",
217+
engine="h5netcdf",
199218
concat_dim="step",
200-
autoclose=True,
219+
# autoclose=True,
201220
parallel=True,
202221
)
203222
ncdf_data = ncdf_data.sortby("global_step")
@@ -206,13 +225,18 @@ def load_benchmark_dataset(root_dir: str, config_file="config.yml") -> Tuple[Sim
206225

207226
try: # try to save data in zarr cache
208227
zarr_data = ncdf_data.to_zarr(ds_path)
228+
touch(ds_cache_OK)
229+
209230
logger.info("Stored data in zarr cache")
210231

211232
# close netcdf4 data sets
212233
ncdf_data.close()
213234
del ncdf_data
214235
data = zarr_data
215-
except: # use netcdf4 since zarr does not seem to work
236+
except BaseException as e: # use netcdf4 since zarr does not seem to work
237+
traceback.print_exc()
238+
239+
logger.info("falling back to step-wise netcdf4 store")
216240
data = ncdf_data
217241

218242
benchmark_data.append(data)

batchglm/benchmark/nb_glm/convergence.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,30 @@
44
import shutil
55
import logging
66

7+
import pandas as pd
78
import scipy.stats
89
import numpy as np
910
import xarray as xr
1011
import yaml
1112

12-
from .base import init_benchmark, get_benchmark_samples, run_benchmark, load_benchmark_dataset
13+
from .base import init_benchmark, get_benchmark_samples, run_benchmark, load_benchmark_dataset, load_config
1314
from .base import Simulator
1415

1516
import batchglm.utils.stats as stat_utils
1617

1718
logger = logging.getLogger(__name__)
1819

1920

20-
def plot_benchmark(root_dir: str, config_file="config.yml"):
21+
def group_by_training(benchmark_configs, keys=("optim_algo", "learning_rate")):
22+
benchmark_df = pd.DataFrame.from_dict({
23+
benchmark: [cfg["training_args"][k] for k in keys]
24+
for benchmark, cfg in benchmark_configs.items()
25+
}, orient='index', columns=keys)
26+
27+
return benchmark_df
28+
29+
30+
def plot_all_benchmarks(root_dir, config_file="config.yml"):
2131
logger.info("loading config...", end="", flush=True)
2232
config_file = os.path.join(root_dir, config_file)
2333
with open(config_file, mode="r") as f:
@@ -28,48 +38,71 @@ def plot_benchmark(root_dir: str, config_file="config.yml"):
2838

2939
logger.info("loading data...", end="", flush=True)
3040
sim, benchmark_data = load_benchmark_dataset(root_dir)
31-
benchmark_data.coords["time_elapsed"] = benchmark_data.time_elapsed.cumsum("step")
3241
logger.info("\t[OK]")
3342

43+
plot_benchmarks(
44+
plot_dir=plot_dir,
45+
sim=sim,
46+
benchmark_data=benchmark_data,
47+
benchmark_names=benchmark_data.coords["benchmark"]
48+
)
49+
50+
51+
def plot_benchmarks(plot_dir: str, sim, benchmark_data, benchmark_names):
52+
benchmark_data = benchmark_data.assign_coords(**{
53+
"time_elapsed": benchmark_data.time_elapsed.cumsum("step"),
54+
})
55+
benchmark_data.coords["benchmark"] = xr.DataArray(
56+
dims=("benchmark",),
57+
data=benchmark_names
58+
)
59+
60+
groupby_col = "benchmark"
61+
linewidth = 0.5
62+
3463
import plotnine as pn
3564
import matplotlib.pyplot as plt
65+
plt.rcParams["legend.loc"] = "center left"
3666

3767
from dask.diagnostics import ProgressBar
3868

3969
def plot_stat(val, val_name, name_prefix, scale_y_log10=False):
4070
with ProgressBar():
41-
df = val.to_dataframe(val_name).reset_index()
71+
df = val.to_dataframe(val_name)
72+
df = df.reset_index()
4273

4374
plot = (pn.ggplot(df)
44-
+ pn.aes(x="time_elapsed", y=val_name, group="benchmark", color="benchmark")
75+
+ pn.aes(x="time_elapsed", y=val_name, group=groupby_col, color=groupby_col)
4576
+ pn.geom_line()
46-
+ pn.geom_vline(xintercept=df.location[[np.argmin(df[val_name])]].time_elapsed.values[0], color="black")
77+
+ pn.geom_vline(xintercept=df.loc[[np.argmin(df[val_name])]].time_elapsed.values[0], color="black")
4778
+ pn.geom_hline(yintercept=np.min(df[val_name]), alpha=0.5)
4879
)
4980
if scale_y_log10:
5081
plot = plot + pn.scale_y_log10()
5182
plot.save(os.path.join(plot_dir, name_prefix + ".time.svg"), format="svg")
5283

5384
plot = (pn.ggplot(df)
54-
+ pn.aes(x="global_step", y=val_name, group="benchmark", color="benchmark")
85+
+ pn.aes(x="global_step", y=val_name, group=groupby_col, color=groupby_col)
5586
+ pn.geom_line()
56-
+ pn.geom_vline(xintercept=df.location[[np.argmin(df[val_name])]].global_step.values[0], color="black")
87+
+ pn.geom_vline(xintercept=df.loc[[np.argmin(df[val_name])]].global_step.values[0], color="black")
5788
+ pn.geom_hline(yintercept=np.min(df[val_name]), alpha=0.5)
5889
)
5990
if scale_y_log10:
6091
plot = plot + pn.scale_y_log10()
6192
plot.save(os.path.join(plot_dir, name_prefix + ".step.svg"), format="svg")
6293

94+
return df
95+
6396
logger.info("plotting...")
6497
val: xr.DataArray = stat_utils.rmsd(
6598
np.exp(xr.DataArray(sim.params["a"][0], dims=("features",))),
6699
np.exp(benchmark_data.a.isel(design_loc_params=0)), axis=[0])
67-
plot_stat(val, "mapd", "real_mu")
100+
df = plot_stat(val, "mapd", "real_mu")
68101

69102
val: xr.DataArray = stat_utils.rmsd(
70103
np.exp(xr.DataArray(sim.params["b"][0], dims=("features",))),
71104
np.exp(benchmark_data.b.isel(design_scale_params=0)), axis=[0])
72-
plot_stat(val, "mapd", "real_r")
105+
df = plot_stat(val, "mapd", "real_r")
73106

74107
val: xr.DataArray = benchmark_data.loss
75108
plot_stat(val, "loss", "loss")
@@ -90,41 +123,65 @@ def plot_pval(window_size):
90123
t = t[:, window_size:]
91124
df = df[:, window_size:]
92125

93-
pval = t.copy()
94-
pval[:, :] = scipy.stats.t(df).cdf(t)
95-
pval.plot.line(hue="benchmark")
96-
plt.savefig(os.path.join(plot_dir, "pval_convergence.%dsteps.svg" % window_size), format="svg")
97-
# plt.show()
98-
plt.close()
126+
pval = xr.DataArray(
127+
name="pval",
128+
data=scipy.stats.t(df).cdf(t),
129+
dims=t.dims,
130+
coords=t.coords
131+
)
132+
133+
fig, ax = plt.subplots()
134+
lines = pval.plot.line(hue=groupby_col, linewidth=linewidth, ax=ax)
135+
ax.get_legend().set_bbox_to_anchor((1, 0.5))
136+
fig.savefig(os.path.join(plot_dir, "pval_convergence.%dsteps.svg" % window_size),
137+
format="svg", bbox_inches='tight')
138+
# fig.show()
139+
plt.close(fig)
99140

141+
plot_pval(25)
142+
plot_pval(50)
100143
plot_pval(100)
101144
plot_pval(200)
102145
plot_pval(400)
103146

104-
benchmark_data.full_loss.plot.line(hue="benchmark")
105-
plt.savefig(os.path.join(plot_dir, "full_loss.svg"), format="svg")
106-
plt.close()
147+
fig, ax = plt.subplots()
148+
lines = benchmark_data.full_loss.plot.line(hue=groupby_col, linewidth=linewidth, ax=ax)
149+
ax.set_ylabel('full loss')
150+
ax.get_legend().set_bbox_to_anchor((1, 0.5))
151+
fig.savefig(os.path.join(plot_dir, "full_loss.svg"), format="svg", bbox_inches='tight')
152+
plt.close(fig)
107153

108-
benchmark_data.loss.plot.line(hue="benchmark")
109-
plt.savefig(os.path.join(plot_dir, "batch_loss.svg"), format="svg")
110-
plt.close()
154+
fig, ax = plt.subplots()
155+
lines = benchmark_data.loss.plot.line(hue=groupby_col, linewidth=linewidth, ax=ax)
156+
ax.set_ylabel('batch loss')
157+
ax.get_legend().set_bbox_to_anchor((1, 0.5))
158+
fig.savefig(os.path.join(plot_dir, "batch_loss.svg"), format="svg", bbox_inches='tight')
159+
plt.close(fig)
111160

112161
def plot_loss_rolling_mean(window_size):
113162
logger.info("plotting rolling mean of batch loss with window size: %d" % window_size)
114163

115-
benchmark_data.loss.rolling(step=window_size).mean().plot.line(hue="benchmark")
116-
plt.savefig(os.path.join(plot_dir, "batch_loss_rolling_mean.%dsteps.svg" % window_size), format="svg")
117-
plt.close()
164+
fig, ax = plt.subplots()
165+
lines = benchmark_data.loss.rolling(step=window_size).mean().plot.line(
166+
hue=groupby_col, linewidth=linewidth, ax=ax)
167+
ax.set_ylabel('rolling mean')
168+
ax.get_legend().set_bbox_to_anchor((1, 0.5))
169+
fig.savefig(os.path.join(plot_dir, "batch_loss_rolling_mean.%dsteps.svg" % window_size),
170+
format="svg", bbox_inches='tight')
171+
plt.close(fig)
118172

119173
plot_loss_rolling_mean(25)
120174
plot_loss_rolling_mean(50)
121175
plot_loss_rolling_mean(100)
122176
plot_loss_rolling_mean(200)
123177

178+
fig, ax = plt.subplots()
124179
with ProgressBar():
125-
benchmark_data.full_gradient.mean(dim="features").plot.line(hue="benchmark")
126-
plt.savefig(os.path.join(plot_dir, "mean_full_gradient.svg"), format="svg")
127-
plt.close()
180+
lines = benchmark_data.full_gradient.mean(dim="features").plot.line(
181+
hue=groupby_col, linewidth=linewidth, ax=ax)
182+
ax.get_legend().set_bbox_to_anchor((1, 0.5))
183+
fig.savefig(os.path.join(plot_dir, "mean_full_gradient.svg"), format="svg", bbox_inches='tight')
184+
plt.close(fig)
128185

129186
logger.info("ready")
130187

@@ -214,7 +271,7 @@ def main():
214271
for smpl in benchmark_samples:
215272
logger.info(smpl)
216273
elif action == "plot":
217-
plot_benchmark(root_dir)
274+
plot_all_benchmarks(root_dir)
218275
elif action == "clean":
219276
clean(root_dir)
220277

0 commit comments

Comments
 (0)