Skip to content

Commit 44b9cdc

Browse files
committed
remove device from attention test
1 parent 315b579 commit 44b9cdc

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

tests/test_enc_dec_att.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_additive_attention():
         # pass key as weight feedback just for testing
         context, weights = att(
-            key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len, device="cpu"
+            key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len
         )
         assert context.shape == (10, 5)
         assert weights.shape == (10, 20, 1)
@@ -42,7 +42,6 @@ def test_encoder_decoder_attention_model():
         output_dropout=0.1,
         zoneout_drop_c=0.0,
         zoneout_drop_h=0.0,
-        device="cpu",
     )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]
@@ -69,7 +68,6 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
         output_dropout=0.1,
         zoneout_drop_c=zoneout_drop_c,
         zoneout_drop_h=zoneout_drop_h,
-        device="cpu",
     )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)

0 commit comments

Comments
 (0)