@@ -39,22 +39,20 @@ def forward(
         query: torch.Tensor,
         weight_feedback: torch.Tensor,
         enc_seq_len: torch.Tensor,
-        device: str,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         :param key: encoder keys of shape [B,T,D_k]
         :param value: encoder values of shape [B,T,D_v]
         :param query: query of shape [B,D_k]
         :param weight_feedback: shape is [B,T,D_k]
         :param enc_seq_len: encoder sequence lengths [B]
-        :param device: device where to run the model (cpu or cuda)
         :return: attention context [B,D_v], attention weights [B,T,1]
         """
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
-        time_arange = torch.arange(energies.size(1), device=device)  # [T]
+        time_arange = torch.arange(energies.size(1), device=energies.device)  # [T]
         seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
-        energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
+        energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf")))
         weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
         weights = self.att_weights_drop(weights)
         context = torch.bmm(weights.transpose(1, 2), value)  # [B,1,D_v]
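Note on the pattern above: `torch.arange(..., device=energies.device)` and `energies.new_tensor(...)` allocate on whatever device the inputs already occupy, which is what lets the `device` argument disappear. A minimal standalone sketch of the same masking logic (shapes and values are illustrative, not from this repo; `enc_seq_len` is assumed to live on the same device as `energies`):

import torch

energies = torch.randn(2, 5, 1)                  # [B,T,1] attention energies
enc_seq_len = torch.tensor([5, 3])               # [B] valid lengths, same device as energies

# arange on the energies' device keeps everything colocated (CPU or CUDA)
time_arange = torch.arange(energies.size(1), device=energies.device)   # [T]
seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
# new_tensor inherits device and dtype from `energies`
energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf")))
weights = torch.nn.functional.softmax(energies, dim=1)                 # padded steps get weight 0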
@@ -76,7 +74,6 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: attention config
     output_proj_dim: output projection dimension
     output_dropout: output dropout
-    device: device where to run the model (cpu or cuda)
     """

     encoder_dim: int
@@ -89,7 +86,6 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: AdditiveAttentionConfig
     output_proj_dim: int
     output_dropout: float
-    device: str


 class AttentionLSTMDecoderV1(nn.Module):
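With `device` gone from the config, placement follows standard PyTorch practice: build the module, move it with `nn.Module.to`, and let the input tensors determine where the forward pass runs. A generic illustration (plain `nn.Linear`, not this repo's classes):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
layer = nn.Linear(4, 4).to(device)               # parameters move once, up front
x = torch.randn(2, 4, device=device)
y = layer(x)                                     # no device threaded through a config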
@@ -100,6 +96,7 @@ class AttentionLSTMDecoderV1(nn.Module):
     def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         super().__init__()

+        print(cfg.vocab_size)
         self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
         self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)
@@ -130,10 +127,6 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
         self.output_dropout = nn.Dropout(cfg.output_dropout)

-        if "cuda" in cfg.device:
-            assert torch.cuda.is_available(), "CUDA is not available"
-        self.device = cfg.device
-
     def forward(
         self,
         encoder_outputs: torch.Tensor,
@@ -148,10 +141,10 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device=self.device)
+            zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size))
             lstm_state = (zeros, zeros)
-            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=self.device)
-            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=self.device)
+            att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
+            accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
         else:
             lstm_state, att_context, accum_att_weights = state
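`Tensor.new_zeros` returns zeros matching the source tensor's dtype and device, so the initial decoder state automatically lands wherever `encoder_outputs` lives. A standalone check (shapes are illustrative; `lstm_hidden_size` is an assumed value, not this repo's):

import torch

encoder_outputs = torch.randn(2, 7, 512)         # [B,T,D] encoder output, CPU or CUDA
lstm_hidden_size = 256                           # illustrative value

zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), lstm_hidden_size))
att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
assert zeros.device == encoder_outputs.device and zeros.dtype == encoder_outputs.dtype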
@@ -187,7 +180,6 @@ def forward(
             query=s_transformed,
             weight_feedback=weight_feedback,
             enc_seq_len=enc_seq_len,
-            device=self.device,
         )
         att_context_list.append(att_context)
         accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5
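The last line above accumulates the attention weights, scaled by an inverse-fertility term; the accumulation is carried through the decoder state and, projected, becomes the `weight_feedback` input to the additive attention at later steps, discouraging repeated attention to the same encoder frames. A toy one-step computation (values illustrative; `enc_inv_fertility` is assumed to be precomputed from the encoder, ones here just for the arithmetic):

import torch

att_weights = torch.tensor([[[0.7], [0.2], [0.1]]])   # [B,T,1] softmax output
enc_inv_fertility = torch.ones(1, 3, 1)               # stand-in for the learned term
accum_att_weights = torch.zeros(1, 3, 1)

accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5
print(accum_att_weights.squeeze(-1))                  # ~[[0.35, 0.10, 0.05]]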