Skip to content

Commit 4ddfae6

Browse files
committed
check for cuda availability
1 parent 8030a30 commit 4ddfae6

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

i6_models/decoder/attention.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config):
130130
self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
131131
self.output_dropout = nn.Dropout(cfg.output_dropout)
132132

133+
if "cuda" in cfg.device:
134+
assert torch.cuda.is_available(), "CUDA is not available"
133135
self.device = cfg.device
134136

135137
def forward(

0 commit comments

Comments
 (0)