1313
1414"""
1515
16-
16+ from __future__ import annotations
17+ from typing import TYPE_CHECKING , Any , Dict
1718import time
1819
1920from .aed import from_scratch_model_def , Model , from_scratch_training
2021
22+ if TYPE_CHECKING :
23+ import torch
24+ from returnn .tensor import Dim
25+
2126
2227def test ():
2328 import torch
@@ -49,6 +54,25 @@ def test():
4954 rf .set_default_device ("cuda" )
5055 print (f"GPU memory usage (allocated model): { util .human_bytes_size (torch .cuda .memory_allocated (dev ))} " )
5156
57+ train_input_kwargs = _generate_dummy_train_input_kwargs (dev = dev , target_dim = target_dim )
58+
59+ # TODO how to setup hooks?
60+
61+ start_time = time .time ()
62+ rf .init_train_step_run_ctx (train_flag = False )
63+ pt_model .eval ()
64+ with torch .no_grad ():
65+ from_scratch_training (model = model , ** train_input_kwargs )
66+ print ("One train forward step, duration:" , util .hms_fraction (time .time () - start_time ), "sec" )
67+ print (f"GPU peak memory allocated: { util .human_bytes_size (torch .cuda .max_memory_allocated (dev ))} " )
68+
69+ for name , loss in rf .get_run_ctx ().losses .items ():
70+ print (f"Loss { name } : { loss .get_mean_loss ().raw_tensor .item ()} " )
71+
72+
73+ def _generate_dummy_train_input_kwargs (* , dev : torch .device , target_dim : Dim ) -> Dict [str , Any ]:
74+ import torch
75+ from returnn .tensor import Tensor , Dim , batch_dim
5276 from i6_experiments .users .zeyer .audio .torch .random_speech_like import generate_random_speech_like_audio
5377
5478 batch_size = 20
@@ -57,7 +81,7 @@ def test():
5781 duration = 10.0
5882 num_frames = int (duration * sample_rate )
5983 print (
60- f"Using dummy batch of size { batch_size * num_frames } raw frames, "
84+ f"Using dummy batch of size { batch_size * num_frames } raw frames"
6185 f" ({ batch_size * num_frames * 100 // sample_rate } 10ms frames)"
6286 )
6387 audio_raw = generate_random_speech_like_audio (batch_size , num_frames , samples_per_sec = sample_rate )
@@ -69,27 +93,11 @@ def test():
6993 targets_len = int (duration * 3 )
7094 targets_lens_raw = torch .tensor ([targets_len ] * batch_size , dtype = torch .int32 )
7195 targets_spatial_dim = Dim (Tensor ("targets_len" , [batch_dim ], dtype = "int32" , raw_tensor = targets_lens_raw ))
72- targets_raw = torch .randint (
73- 0 , model .target_dim .dimension , size = (batch_size , targets_len ), dtype = torch .int32 , device = dev
74- )
96+ targets_raw = torch .randint (0 , target_dim .dimension , size = (batch_size , targets_len ), dtype = torch .int32 , device = dev )
7597 targets = Tensor (
76- "targets" , [batch_dim , targets_spatial_dim ], dtype = "int32" , sparse_dim = model . target_dim , raw_tensor = targets_raw
98+ "targets" , [batch_dim , targets_spatial_dim ], dtype = "int32" , sparse_dim = target_dim , raw_tensor = targets_raw
7799 )
78100
79- start_time = time .time ()
80- rf .init_train_step_run_ctx (train_flag = False )
81- pt_model .eval ()
82- with torch .no_grad ():
83- # TODO how to setup hooks?
84- from_scratch_training (
85- model = model ,
86- data = audio ,
87- data_spatial_dim = audio_spatial_dim ,
88- targets = targets ,
89- targets_spatial_dim = targets_spatial_dim ,
90- )
91- print ("One train forward step, duration:" , util .hms_fraction (time .time () - start_time ), "sec" )
92- print (f"GPU peak memory allocated: { util .human_bytes_size (torch .cuda .max_memory_allocated (dev ))} " )
93-
94- for name , loss in rf .get_run_ctx ().losses .items ():
95- print (f"Loss { name } : { loss .get_mean_loss ().raw_tensor .item ()} " )
101+ return dict (
102+ data = audio , data_spatial_dim = audio_spatial_dim , targets = targets , targets_spatial_dim = targets_spatial_dim
103+ )
0 commit comments