@@ -32,8 +32,9 @@ def test():
3232 from returnn .config import get_global_config
3333 from returnn .datasets .util .vocabulary import Vocabulary
3434 import returnn .frontend as rf
35- from returnn .tensor import Tensor , Dim , batch_dim
35+ from returnn .tensor import Dim
3636 from returnn .torch .frontend .bridge import rf_module_to_pt_module
37+ from i6_experiments .users .zeyer .audio .torch .random_speech_like import generate_dummy_train_input_kwargs
3738
3839 better_exchook .install ()
3940 config = get_global_config (auto_create = True )
@@ -54,7 +55,7 @@ def test():
5455 rf .set_default_device ("cuda" )
5556 print (f"GPU memory usage (allocated model): { util .human_bytes_size (torch .cuda .memory_allocated (dev ))} " )
5657
57- train_input_kwargs = _generate_dummy_train_input_kwargs (dev = dev , target_dim = target_dim )
58+ train_input_kwargs = generate_dummy_train_input_kwargs (dev = dev , target_dim = target_dim )
5859
5960 # TODO how to setup hooks?
6061
@@ -68,36 +69,3 @@ def test():
6869
6970 for name , loss in rf .get_run_ctx ().losses .items ():
7071 print (f"Loss { name } : { loss .get_mean_loss ().raw_tensor .item ()} " )
71-
72-
73- def _generate_dummy_train_input_kwargs (* , dev : torch .device , target_dim : Dim ) -> Dict [str , Any ]:
74- import torch
75- from returnn .tensor import Tensor , Dim , batch_dim
76- from i6_experiments .users .zeyer .audio .torch .random_speech_like import generate_random_speech_like_audio
77-
78- batch_size = 20
79- batch_dim .dyn_size_ext = Tensor ("batch" , [], dtype = "int32" , raw_tensor = torch .tensor (batch_size , dtype = torch .int32 ))
80- sample_rate = 16_000
81- duration = 10.0
82- num_frames = int (duration * sample_rate )
83- print (
84- f"Using dummy batch of size { batch_size * num_frames } raw frames"
85- f" ({ batch_size * num_frames * 100 // sample_rate } 10ms frames)"
86- )
87- audio_raw = generate_random_speech_like_audio (batch_size , num_frames , samples_per_sec = sample_rate )
88- audio_raw = audio_raw .to (dev )
89- audio_lens_raw = torch .tensor ([num_frames ] * batch_size , dtype = torch .int32 )
90- audio_spatial_dim = Dim (Tensor ("time" , [batch_dim ], dtype = "int32" , raw_tensor = audio_lens_raw ))
91- audio = Tensor ("audio" , [batch_dim , audio_spatial_dim ], dtype = "float32" , raw_tensor = audio_raw )
92-
93- targets_len = int (duration * 3 )
94- targets_lens_raw = torch .tensor ([targets_len ] * batch_size , dtype = torch .int32 )
95- targets_spatial_dim = Dim (Tensor ("targets_len" , [batch_dim ], dtype = "int32" , raw_tensor = targets_lens_raw ))
96- targets_raw = torch .randint (0 , target_dim .dimension , size = (batch_size , targets_len ), dtype = torch .int32 , device = dev )
97- targets = Tensor (
98- "targets" , [batch_dim , targets_spatial_dim ], dtype = "int32" , sparse_dim = target_dim , raw_tensor = targets_raw
99- )
100-
101- return dict (
102- data = audio , data_spatial_dim = audio_spatial_dim , targets = targets , targets_spatial_dim = targets_spatial_dim
103- )
0 commit comments