Skip to content

Commit 0bfc014

Browse files
committed
more
1 parent e4628d3 commit 0bfc014

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

users/zeyer/experiments/exp2023_04_25_rf/auto_rnd_init.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test():
3131
from returnn.util import basic as util
3232
from returnn.datasets.util.vocabulary import Vocabulary
3333
import returnn.frontend as rf
34-
from returnn.tensor import Dim
34+
from returnn.tensor import Tensor, Dim
3535
from i6_experiments.users.zeyer.audio.torch.random_speech_like import generate_dummy_train_input_kwargs
3636

3737
better_exchook.install()
@@ -51,7 +51,23 @@ def test():
5151

5252
train_input_kwargs = generate_dummy_train_input_kwargs(dev=rf.get_default_device(), target_dim=target_dim)
5353

54-
# TODO how to setup hooks?
54+
_mod_id_to_name = {}
55+
56+
def _hook(mod, args, kwargs, result):
57+
name = _mod_id_to_name[id(mod)]
58+
(x,) = args
59+
x: Tensor
60+
y: Tensor = result
61+
mean, var = rf.moments(x, axis=x.dims)
62+
stddev = rf.sqrt(var)
63+
print("*", name, "mean:")
64+
# TODO rescale...
65+
return y
66+
67+
for name, mod in model.named_modules():
68+
if isinstance(mod, rf.Linear):
69+
mod.register_forward_hook(_hook)
70+
_mod_id_to_name[id(mod)] = name
5571

5672
start_time = time.time()
5773
rf.init_train_step_run_ctx(train_flag=False)

0 commit comments

Comments
 (0)