Skip to content

Commit 1af08a0

Browse files
committed
Merge remote-tracking branch 'remotes/theislab/dev' into dev
2 parents fd9e61e + 2285375 commit 1af08a0

File tree

3 files changed

+47
-77
lines changed

3 files changed

+47
-77
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ You can install [tensorflow](https://www.tensorflow.org/install/) via pip or via
1717

1818
#### pip
1919
- CPU-only: <br/>
20-
`pip install tf-nightly`
20+
`pip install tensorflow`
2121
- GPU: <br/>
22-
`pip install tf-nightly-gpu`
22+
`pip install tensorflow-gpu`
2323

2424
### Hardware-optimized tensorflow installation (compiling from source)
2525
Please refer to https://www.tensorflow.org/install/.

batchglm/train/tf/nb_glm/estimator.py

Lines changed: 36 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -585,104 +585,72 @@ class TrainingStrategy(Enum):
585585
AUTO = None
586586
DEFAULT = [
587587
{
588-
"learning_rate": 0.1,
589-
"convergence_criteria": "t_test",
590-
"stopping_criteria": 0.05,
591-
"loss_window_size": 100,
592-
"use_batching": True,
588+
"learning_rate": 0.5,
589+
"convergence_criteria": "scaled_moving_average",
590+
"stopping_criteria": 1e-5,
591+
"loss_window_size": 10,
592+
"use_batching": False,
593593
"optim_algo": "ADAM",
594594
},
595595
{
596-
"learning_rate": 0.05,
597-
"convergence_criteria": "t_test",
598-
"stopping_criteria": 0.05,
596+
"convergence_criteria": "scaled_moving_average",
597+
"stopping_criteria": 1e-10,
599598
"loss_window_size": 10,
600599
"use_batching": False,
601-
"optim_algo": "ADAM",
600+
"optim_algo": "newton",
602601
},
603602
]
604603
EXACT = [
605604
{
606-
"learning_rate": 0.1,
607-
"convergence_criteria": "t_test",
608-
"stopping_criteria": 0.05,
609-
"loss_window_size": 100,
610-
"use_batching": True,
611-
"optim_algo": "ADAM",
612-
},
613-
{
614-
"learning_rate": 0.05,
615-
"convergence_criteria": "t_test",
616-
"stopping_criteria": 0.05,
617-
"loss_window_size": 100,
618-
"use_batching": True,
619-
"optim_algo": "ADAM",
620-
},
621-
{
622-
"learning_rate": 0.005,
623-
"convergence_criteria": "t_test",
624-
"stopping_criteria": 0.25,
605+
"learning_rate": 0.5,
606+
"convergence_criteria": "scaled_moving_average",
607+
"stopping_criteria": 1e-5,
625608
"loss_window_size": 10,
626609
"use_batching": False,
627-
"optim_algo": "Newton-Raphson",
628-
},
629-
]
630-
QUICK = [
631-
{
632-
"learning_rate": 0.1,
633-
"convergence_criteria": "t_test",
634-
"stopping_criteria": 0.05,
635-
"loss_window_size": 100,
636-
"use_batching": True,
637610
"optim_algo": "ADAM",
638611
},
639-
]
640-
PRE_INITIALIZED = [
641612
{
642-
"learning_rate": 0.01,
643-
"convergence_criteria": "t_test",
644-
"stopping_criteria": 0.25,
613+
"convergence_criteria": "scaled_moving_average",
614+
"stopping_criteria": 1e-10,
645615
"loss_window_size": 10,
646616
"use_batching": False,
647-
"optim_algo": "ADAM",
617+
"optim_algo": "newton",
648618
},
649619
]
650-
NEWTON_EXACT = [
620+
QUICK = [
651621
{
652-
"learning_rate": 1,
622+
"learning_rate": 0.5,
653623
"convergence_criteria": "scaled_moving_average",
654624
"stopping_criteria": 1e-8,
655-
"loss_window_size": 5,
625+
"loss_window_size": 10,
656626
"use_batching": False,
657-
"optim_algo": "newton-raphson",
627+
"optim_algo": "ADAM",
658628
},
659629
]
660-
NEWTON_BATCHED = [
630+
PRE_INITIALIZED = [
661631
{
662-
"learning_rate": 1,
663632
"convergence_criteria": "scaled_moving_average",
664-
"stopping_criteria": 1e-8,
665-
"loss_window_size": 20,
666-
"use_batching": True,
667-
"optim_algo": "newton-raphson",
633+
"stopping_criteria": 1e-10,
634+
"loss_window_size": 10,
635+
"use_batching": False,
636+
"optim_algo": "newton",
668637
},
669638
]
670-
NEWTON_SERIES = [
639+
CONTINUOUS = [
671640
{
672-
"learning_rate": 1,
641+
"learning_rate": 0.5,
673642
"convergence_criteria": "scaled_moving_average",
674-
"stopping_criteria": 1e-8,
675-
"loss_window_size": 8,
676-
"use_batching": True,
677-
"optim_algo": "newton-raphson",
643+
"stopping_criteria": 1e-5,
644+
"loss_window_size": 10,
645+
"use_batching": False,
646+
"optim_algo": "ADAM",
678647
},
679648
{
680-
"learning_rate": 1,
681649
"convergence_criteria": "scaled_moving_average",
682-
"stopping_criteria": 1e-8,
683-
"loss_window_size": 4,
650+
"stopping_criteria": 1e-10,
651+
"loss_window_size": 10,
684652
"use_batching": False,
685-
"optim_algo": "newton-raphson",
653+
"optim_algo": "newton",
686654
},
687655
]
688656

@@ -935,11 +903,8 @@ def __init__(
935903
my_loc_names = set(input_data.design_loc_names.values)
936904
my_loc_names = my_loc_names.intersection(init_model.input_data.design_loc_names.values)
937905

938-
init_loc = np.random.uniform(
939-
low=np.nextafter(0, 1, dtype=input_data.X.dtype),
940-
high=np.sqrt(np.nextafter(0, 1, dtype=input_data.X.dtype)),
941-
size=(input_data.num_design_loc_params, input_data.num_features)
942-
)
906+
# Initialize new parameters to zero:
907+
init_loc = np.zeros(shape=(input_data.num_design_loc_params, input_data.num_features))
943908
for parm in my_loc_names:
944909
init_idx = np.where(init_model.input_data.design_loc_names == parm)
945910
my_idx = np.where(input_data.design_loc_names == parm)
@@ -952,11 +917,8 @@ def __init__(
952917
my_scale_names = set(input_data.design_scale_names.values)
953918
my_scale_names = my_scale_names.intersection(init_model.input_data.design_scale_names.values)
954919

955-
init_scale = np.random.uniform(
956-
low=np.nextafter(0, 1, dtype=input_data.X.dtype),
957-
high=np.sqrt(np.nextafter(0, 1, dtype=input_data.X.dtype)),
958-
size=(input_data.num_design_scale_params, input_data.num_features)
959-
)
920+
# Initialize new parameters to zero:
921+
init_scale = np.zeros(shape=(input_data.num_design_scale_params, input_data.num_features))
960922
for parm in my_scale_names:
961923
init_idx = np.where(init_model.input_data.design_scale_names == parm)
962924
my_idx = np.where(input_data.design_scale_names == parm)

setup.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,19 @@
22
import versioneer
33

44
author = 'Florian R. Hölzlwimmer, David S. Fischer'
5+
6+
description="Fast and scalable fitting of over-determined generalized-linear models (GLMs)"
7+
8+
with open("README.md", "r") as fh:
9+
long_description = fh.read()
510

611
setup(
712
name='batchglm',
813
author=author,
9-
author_email='[email protected]',
14+
author_email=author_email,
15+
description=description,
16+
long_description=long_description,
17+
long_description_content_type="text/markdown",
1018
packages=find_packages(),
1119
install_requires=[
1220
'tensorflow>=1.10.0',

0 commit comments

Comments
 (0)