Skip to content

Commit e62119c

Browse files
committed
more
1 parent e827e9f commit e62119c

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

users/zeyer/experiments/exp2023_04_25_rf/auto_rnd_init.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,20 @@ def test():
4646
pt_model = rf_module_to_pt_module(model)
4747
dev = torch.device("cuda")
4848
pt_model.to(dev)
49-
pt_model.eval()
5049
rf.set_default_device("cuda")
50+
print(f"GPU memory usage (allocated model): {util.human_bytes_size(torch.cuda.memory_allocated(dev))}")
5151

5252
from i6_experiments.users.zeyer.audio.torch.random_speech_like import generate_random_speech_like_audio
5353

54-
batch_size = 10
54+
batch_size = 20
5555
batch_dim.dyn_size_ext = Tensor("batch", [], dtype="int32", raw_tensor=torch.tensor(batch_size, dtype=torch.int32))
5656
sample_rate = 16_000
57-
duration = 5.0
57+
duration = 10.0
5858
num_frames = int(duration * sample_rate)
59+
print(
60+
f"Using dummy batch of size {batch_size * num_frames} raw frames,"
61+
f" ({batch_size * num_frames * 100 // sample_rate} 10ms frames)"
62+
)
5963
audio_raw = generate_random_speech_like_audio(batch_size, num_frames, samples_per_sec=sample_rate)
6064
audio_raw = audio_raw.to(dev)
6165
audio_lens_raw = torch.tensor([num_frames] * batch_size, dtype=torch.int32)
@@ -65,22 +69,27 @@ def test():
6569
targets_len = int(duration * 3)
6670
targets_lens_raw = torch.tensor([targets_len] * batch_size, dtype=torch.int32)
6771
targets_spatial_dim = Dim(Tensor("targets_len", [batch_dim], dtype="int32", raw_tensor=targets_lens_raw))
68-
targets_raw = torch.randint(0, model.target_dim.dimension, size=(batch_size, targets_len), dtype=torch.int32, device=dev)
72+
targets_raw = torch.randint(
73+
0, model.target_dim.dimension, size=(batch_size, targets_len), dtype=torch.int32, device=dev
74+
)
6975
targets = Tensor(
7076
"targets", [batch_dim, targets_spatial_dim], dtype="int32", sparse_dim=model.target_dim, raw_tensor=targets_raw
7177
)
7278

7379
start_time = time.time()
7480
rf.init_train_step_run_ctx(train_flag=False)
75-
# TODO how to setup hooks?
76-
from_scratch_training(
77-
model=model,
78-
data=audio,
79-
data_spatial_dim=audio_spatial_dim,
80-
targets=targets,
81-
targets_spatial_dim=targets_spatial_dim,
82-
)
81+
pt_model.eval()
82+
with torch.no_grad():
83+
# TODO how to setup hooks?
84+
from_scratch_training(
85+
model=model,
86+
data=audio,
87+
data_spatial_dim=audio_spatial_dim,
88+
targets=targets,
89+
targets_spatial_dim=targets_spatial_dim,
90+
)
8391
print("One train forward step, duration:", util.hms_fraction(time.time() - start_time), "sec")
92+
print(f"GPU peak memory allocated: {util.human_bytes_size(torch.cuda.max_memory_allocated(dev))}")
8493

8594
for name, loss in rf.get_run_ctx().losses.items():
8695
print(f"Loss {name}: {loss.get_mean_loss().raw_tensor.item()}")

0 commit comments

Comments
 (0)