@@ -15,9 +15,7 @@ def test_additive_attention():
1515 enc_seq_len = torch .arange (start = 10 , end = 20 ) # [10, ..., 19]
1616
1717 # pass key as weight feedback just for testing
18- context , weights = att (
19- key = key , value = value , query = query , weight_feedback = key , enc_seq_len = enc_seq_len , device = "cpu"
20- )
18+ context , weights = att (key = key , value = value , query = query , weight_feedback = key , enc_seq_len = enc_seq_len )
2119 assert context .shape == (10 , 5 )
2220 assert weights .shape == (10 , 20 , 1 )
2321
@@ -42,7 +40,6 @@ def test_encoder_decoder_attention_model():
4240 output_dropout = 0.1 ,
4341 zoneout_drop_c = 0.0 ,
4442 zoneout_drop_h = 0.0 ,
45- device = "cpu" ,
4643 )
4744 decoder = AttentionLSTMDecoderV1 (decoder_cfg )
4845 target_labels = torch .randint (low = 0 , high = 15 , size = (10 , 7 )) # [B,N]
@@ -69,7 +66,6 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
6966 output_dropout = 0.1 ,
7067 zoneout_drop_c = zoneout_drop_c ,
7168 zoneout_drop_h = zoneout_drop_h ,
72- device = "cpu" ,
7369 )
7470 decoder = AttentionLSTMDecoderV1 (decoder_cfg )
7571 decoder_logits , _ = decoder (encoder_outputs = encoder , labels = target_labels , enc_seq_len = encoder_seq_len )
0 commit comments