Skip to content

Commit 724bdee

Browse files
committed
more
1 parent e62119c commit 724bdee

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

users/zeyer/experiments/exp2023_04_25_rf/auto_rnd_init.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313
1414
"""
1515

16-
16+
from __future__ import annotations
17+
from typing import TYPE_CHECKING, Any, Dict
1718
import time
1819

1920
from .aed import from_scratch_model_def, Model, from_scratch_training
2021

22+
if TYPE_CHECKING:
23+
import torch
24+
from returnn.tensor import Dim
25+
2126

2227
def test():
2328
import torch
@@ -49,6 +54,25 @@ def test():
4954
rf.set_default_device("cuda")
5055
print(f"GPU memory usage (allocated model): {util.human_bytes_size(torch.cuda.memory_allocated(dev))}")
5156

57+
train_input_kwargs = _generate_dummy_train_input_kwargs(dev=dev, target_dim=target_dim)
58+
59+
# TODO: how to set up hooks?
60+
61+
start_time = time.time()
62+
rf.init_train_step_run_ctx(train_flag=False)
63+
pt_model.eval()
64+
with torch.no_grad():
65+
from_scratch_training(model=model, **train_input_kwargs)
66+
print("One train forward step, duration:", util.hms_fraction(time.time() - start_time), "sec")
67+
print(f"GPU peak memory allocated: {util.human_bytes_size(torch.cuda.max_memory_allocated(dev))}")
68+
69+
for name, loss in rf.get_run_ctx().losses.items():
70+
print(f"Loss {name}: {loss.get_mean_loss().raw_tensor.item()}")
71+
72+
73+
def _generate_dummy_train_input_kwargs(*, dev: torch.device, target_dim: Dim) -> Dict[str, Any]:
74+
import torch
75+
from returnn.tensor import Tensor, Dim, batch_dim
5276
from i6_experiments.users.zeyer.audio.torch.random_speech_like import generate_random_speech_like_audio
5377

5478
batch_size = 20
@@ -57,7 +81,7 @@ def test():
5781
duration = 10.0
5882
num_frames = int(duration * sample_rate)
5983
print(
60-
f"Using dummy batch of size {batch_size * num_frames} raw frames,"
84+
f"Using dummy batch of size {batch_size * num_frames} raw frames"
6185
f" ({batch_size * num_frames * 100 // sample_rate} 10ms frames)"
6286
)
6387
audio_raw = generate_random_speech_like_audio(batch_size, num_frames, samples_per_sec=sample_rate)
@@ -69,27 +93,11 @@ def test():
6993
targets_len = int(duration * 3)
7094
targets_lens_raw = torch.tensor([targets_len] * batch_size, dtype=torch.int32)
7195
targets_spatial_dim = Dim(Tensor("targets_len", [batch_dim], dtype="int32", raw_tensor=targets_lens_raw))
72-
targets_raw = torch.randint(
73-
0, model.target_dim.dimension, size=(batch_size, targets_len), dtype=torch.int32, device=dev
74-
)
96+
targets_raw = torch.randint(0, target_dim.dimension, size=(batch_size, targets_len), dtype=torch.int32, device=dev)
7597
targets = Tensor(
76-
"targets", [batch_dim, targets_spatial_dim], dtype="int32", sparse_dim=model.target_dim, raw_tensor=targets_raw
98+
"targets", [batch_dim, targets_spatial_dim], dtype="int32", sparse_dim=target_dim, raw_tensor=targets_raw
7799
)
78100

79-
start_time = time.time()
80-
rf.init_train_step_run_ctx(train_flag=False)
81-
pt_model.eval()
82-
with torch.no_grad():
83-
# TODO how to setup hooks?
84-
from_scratch_training(
85-
model=model,
86-
data=audio,
87-
data_spatial_dim=audio_spatial_dim,
88-
targets=targets,
89-
targets_spatial_dim=targets_spatial_dim,
90-
)
91-
print("One train forward step, duration:", util.hms_fraction(time.time() - start_time), "sec")
92-
print(f"GPU peak memory allocated: {util.human_bytes_size(torch.cuda.max_memory_allocated(dev))}")
93-
94-
for name, loss in rf.get_run_ctx().losses.items():
95-
print(f"Loss {name}: {loss.get_mean_loss().raw_tensor.item()}")
101+
return dict(
102+
data=audio, data_spatial_dim=audio_spatial_dim, targets=targets, targets_spatial_dim=targets_spatial_dim
103+
)

0 commit comments

Comments
 (0)