File tree Expand file tree Collapse file tree 1 file changed +18
-2
lines changed
users/zeyer/experiments/exp2023_04_25_rf Expand file tree Collapse file tree 1 file changed +18
-2
lines changed Original file line number Diff line number Diff line change @@ -31,7 +31,7 @@ def test():
3131 from returnn .util import basic as util
3232 from returnn .datasets .util .vocabulary import Vocabulary
3333 import returnn .frontend as rf
34- from returnn .tensor import Dim
34+ from returnn .tensor import Tensor , Dim
3535 from i6_experiments .users .zeyer .audio .torch .random_speech_like import generate_dummy_train_input_kwargs
3636
3737 better_exchook .install ()
@@ -51,7 +51,23 @@ def test():
5151
5252 train_input_kwargs = generate_dummy_train_input_kwargs (dev = rf .get_default_device (), target_dim = target_dim )
5353
54- # TODO how to setup hooks?
54+ _mod_id_to_name = {}
55+
56+ def _hook (mod , args , kwargs , result ):
57+ name = _mod_id_to_name [id (mod )]
58+ (x ,) = args
59+ x : Tensor
60+ y : Tensor = result
61+ mean , var = rf .moments (x , axis = x .dims )
62+ stddev = rf .sqrt (var )
63+ print ("*" , name , "mean:" )
64+ # TODO rescale...
65+ return y
66+
67+ for name , mod in model .named_modules ():
68+ if isinstance (mod , rf .Linear ):
69+ mod .register_forward_hook (_hook )
70+ _mod_id_to_name [id (mod )] = name
5571
5672 start_time = time .time ()
5773 rf .init_train_step_run_ctx (train_flag = False )
You can’t perform that action at this time.
0 commit comments