@@ -46,16 +46,20 @@ def test():
4646 pt_model = rf_module_to_pt_module (model )
4747 dev = torch .device ("cuda" )
4848 pt_model .to (dev )
49- pt_model .eval ()
5049 rf .set_default_device ("cuda" )
50+ print (f"GPU memory usage (allocated model): { util .human_bytes_size (torch .cuda .memory_allocated (dev ))} " )
5151
5252 from i6_experiments .users .zeyer .audio .torch .random_speech_like import generate_random_speech_like_audio
5353
54- batch_size = 10
54+ batch_size = 20
5555 batch_dim .dyn_size_ext = Tensor ("batch" , [], dtype = "int32" , raw_tensor = torch .tensor (batch_size , dtype = torch .int32 ))
5656 sample_rate = 16_000
57- duration = 5 .0
57+ duration = 10 .0
5858 num_frames = int (duration * sample_rate )
59+ print (
60+ f"Using dummy batch of size { batch_size * num_frames } raw frames,"
61+ f" ({ batch_size * num_frames * 100 // sample_rate } 10ms frames)"
62+ )
5963 audio_raw = generate_random_speech_like_audio (batch_size , num_frames , samples_per_sec = sample_rate )
6064 audio_raw = audio_raw .to (dev )
6165 audio_lens_raw = torch .tensor ([num_frames ] * batch_size , dtype = torch .int32 )
@@ -65,22 +69,27 @@ def test():
6569 targets_len = int (duration * 3 )
6670 targets_lens_raw = torch .tensor ([targets_len ] * batch_size , dtype = torch .int32 )
6771 targets_spatial_dim = Dim (Tensor ("targets_len" , [batch_dim ], dtype = "int32" , raw_tensor = targets_lens_raw ))
68- targets_raw = torch .randint (0 , model .target_dim .dimension , size = (batch_size , targets_len ), dtype = torch .int32 , device = dev )
72+ targets_raw = torch .randint (
73+ 0 , model .target_dim .dimension , size = (batch_size , targets_len ), dtype = torch .int32 , device = dev
74+ )
6975 targets = Tensor (
7076 "targets" , [batch_dim , targets_spatial_dim ], dtype = "int32" , sparse_dim = model .target_dim , raw_tensor = targets_raw
7177 )
7278
7379 start_time = time .time ()
7480 rf .init_train_step_run_ctx (train_flag = False )
75- # TODO how to setup hooks?
76- from_scratch_training (
77- model = model ,
78- data = audio ,
79- data_spatial_dim = audio_spatial_dim ,
80- targets = targets ,
81- targets_spatial_dim = targets_spatial_dim ,
82- )
81+ pt_model .eval ()
82+ with torch .no_grad ():
83+ # TODO how to setup hooks?
84+ from_scratch_training (
85+ model = model ,
86+ data = audio ,
87+ data_spatial_dim = audio_spatial_dim ,
88+ targets = targets ,
89+ targets_spatial_dim = targets_spatial_dim ,
90+ )
8391 print ("One train forward step, duration:" , util .hms_fraction (time .time () - start_time ), "sec" )
92+ print (f"GPU peak memory allocated: { util .human_bytes_size (torch .cuda .max_memory_allocated (dev ))} " )
8493
8594 for name , loss in rf .get_run_ctx ().losses .items ():
8695 print (f"Loss { name } : { loss .get_mean_loss ().raw_tensor .item ()} " )
0 commit comments