Commit 36fbe25
Commit message: cleanup
1 parent: f17a881

2 files changed: +9 -7 lines

users/zeyer/audio/torch/random_speech_like.py

Lines changed: 7 additions & 2 deletions
@@ -5,7 +5,7 @@
 """
 
 
-from typing import Optional, Any, Dict
+from typing import Optional, Union, Any, Dict
 import torch
 from returnn.tensor import Tensor, Dim, batch_dim
 
@@ -75,7 +75,12 @@ def _integrate_rnd_frequencies(
 
 
 def generate_dummy_train_input_kwargs(
-    *, batch_size: int = 20, duration_secs: float = 10.0, sample_rate: int = 16_000, dev: torch.device, target_dim: Dim
+    *,
+    batch_size: int = 20,
+    duration_secs: float = 10.0,
+    sample_rate: int = 16_000,
+    dev: Optional[Union[torch.device, str]] = None,
+    target_dim: Dim,
 ) -> Dict[str, Any]:
     batch_dim.dyn_size_ext = Tensor("batch", [], dtype="int32", raw_tensor=torch.tensor(batch_size, dtype=torch.int32))
     num_frames = int(duration_secs * sample_rate)
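
The second hunk widens `dev` from a required `torch.device` to `Optional[Union[torch.device, str]] = None`, so callers may pass a device object, a device string, or nothing. A hypothetical call-site sketch (not part of this commit; the target `Dim` size is made up, running it requires an active RETURNN PyTorch backend, and how the body resolves `dev=None` is not visible in this hunk):

import torch
from returnn.tensor import Dim
from i6_experiments.users.zeyer.audio.torch.random_speech_like import generate_dummy_train_input_kwargs

target_dim = Dim(1030, name="targets")  # hypothetical target vocabulary size

# All three forms now satisfy the widened `dev` annotation:
kwargs_a = generate_dummy_train_input_kwargs(dev=torch.device("cpu"), target_dim=target_dim)
kwargs_b = generate_dummy_train_input_kwargs(dev="cpu", target_dim=target_dim)
kwargs_c = generate_dummy_train_input_kwargs(target_dim=target_dim)  # dev defaults to None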

users/zeyer/experiments/exp2023_04_25_rf/auto_rnd_init.py

Lines changed: 2 additions & 5 deletions
@@ -32,7 +32,6 @@ def test():
     from returnn.datasets.util.vocabulary import Vocabulary
     import returnn.frontend as rf
     from returnn.tensor import Dim
-    from returnn.torch.frontend.bridge import rf_module_to_pt_module
     from i6_experiments.users.zeyer.audio.torch.random_speech_like import generate_dummy_train_input_kwargs
 
     better_exchook.install()
@@ -45,19 +44,17 @@ def test():
     model: Model = from_scratch_model_def(epoch=1, in_dim=Dim(1, name="in"), target_dim=target_dim)
     print("Num model parameters:", util.human_size(sum(p.num_elements() for p in model.parameters())))
 
-    pt_model = rf_module_to_pt_module(model)
     dev = torch.device("cuda")
-    pt_model.to(dev)
     rf.set_default_device("cuda")
+    model.to(device=rf.get_default_device())
     print(f"GPU memory usage (allocated model): {util.human_bytes_size(torch.cuda.memory_allocated(dev))}")
 
-    train_input_kwargs = generate_dummy_train_input_kwargs(dev=dev, target_dim=target_dim)
+    train_input_kwargs = generate_dummy_train_input_kwargs(dev=rf.get_default_device(), target_dim=target_dim)
 
     # TODO how to setup hooks?
 
     start_time = time.time()
     rf.init_train_step_run_ctx(train_flag=False)
-    pt_model.eval()
     with torch.no_grad():
         from_scratch_training(model=model, **train_input_kwargs)
     print("One train forward step, duration:", util.hms_fraction(time.time() - start_time), "sec")
