@@ -39,18 +39,20 @@ def forward(
         query: torch.Tensor,
         weight_feedback: torch.Tensor,
         enc_seq_len: torch.Tensor,
+        device: str,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         :param key: encoder keys of shape [B,T,D_k]
         :param value: encoder values of shape [B,T,D_v]
         :param query: query of shape [B,D_k]
         :param weight_feedback: shape is [B,T,D_k]
         :param enc_seq_len: encoder sequence lengths [B]
+        :param device: device to run the model on (cpu or cuda)
         :return: attention context [B,D_v], attention weights [B,T,1]
         """
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
-        time_arange = torch.arange(energies.size(1), device="cuda")  # [T]
+        time_arange = torch.arange(energies.size(1), device=device)  # [T]
         seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
         energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
         weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
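
As a sanity check on the masking logic in this hunk, here is a minimal standalone sketch with toy shapes, where `"cpu"` stands in for the new `device` argument: padded time steps get `-inf` energy and therefore zero softmax weight.

```python
import torch

device = "cpu"  # stands in for the device argument threaded through above
energies = torch.randn(2, 4, 1, device=device)  # [B,T,1]
enc_seq_len = torch.tensor([4, 2], device=device)  # [B]

time_arange = torch.arange(energies.size(1), device=device)  # [T]
seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf"), device=device))
weights = torch.nn.functional.softmax(energies, dim=1)  # [B,T,1]

assert torch.all(weights[1, 2:] == 0)  # steps beyond enc_seq_len get zero weight
```
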
@@ -74,6 +76,7 @@ class AttentionLSTMDecoderV1Config:
         attention_cfg: attention config
         output_proj_dim: output projection dimension
         output_dropout: output dropout
+        device: device to run the model on (cpu or cuda)
     """

     encoder_dim: int
@@ -86,6 +89,7 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: AdditiveAttentionConfig
     output_proj_dim: int
     output_dropout: float
+    device: str


 class AttentionLSTMDecoderV1(nn.Module):
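
Since the device is now a plain config field rather than a hard-coded `"cuda"` string, callers can choose it at construction time. A minimal sketch, assuming any string accepted by `torch.device` is valid here:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Pass device=device along with the other fields when building
# AttentionLSTMDecoderV1Config (the remaining fields are not all
# shown in this diff, so they are omitted here).
```
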
@@ -126,6 +130,8 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
         self.output_dropout = nn.Dropout(cfg.output_dropout)

+        self.device = cfg.device
+
     def forward(
         self,
         encoder_outputs: torch.Tensor,
@@ -140,10 +146,10 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device="cuda")
+            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device=self.device)
             lstm_state = (zeros, zeros)
-            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device="cuda")
-            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device="cuda")
+            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=self.device)
+            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=self.device)
         else:
             lstm_state, att_context, accum_att_weights = state

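
For context, a standalone sketch of the zero-state initialization above with toy sizes; note that `encoder_outputs.device` would be an equivalent way to pick the device here without storing it on the module:

```python
import torch

B, T, enc_dim, lstm_hidden = 2, 5, 8, 16  # toy sizes, not from the diff
device = "cpu"  # stands in for self.device
encoder_outputs = torch.randn(B, T, enc_dim, device=device)

zeros = torch.zeros((encoder_outputs.size(0), lstm_hidden), device=device)
lstm_state = (zeros, zeros)  # (h_0, c_0) for the decoder LSTM
att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=device)  # [B,D_v]
accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=device)  # [B,T,1]
```
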
@@ -179,6 +185,7 @@ def forward(
                 query=s_transformed,
                 weight_feedback=weight_feedback,
                 enc_seq_len=enc_seq_len,
+                device=self.device,
             )
             att_context_list.append(att_context)
             accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5
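
One caveat worth noting: the new `device` field only places tensors created on the fly inside `forward`; parameters registered in `__init__` stay wherever PyTorch created them by default, so the module itself still needs an explicit move. A hedged usage sketch, assuming `cfg` is a fully populated `AttentionLSTMDecoderV1Config`:

```python
model = AttentionLSTMDecoderV1(cfg)
model = model.to(cfg.device)  # parameters are not moved by the config field alone
```
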