11import numbers
22import warnings
33from collections import namedtuple
4- from typing import List , Tuple
4+ from typing import List
55
66import torch
77import torch .jit as jit
@@ -131,8 +131,8 @@ def __init__(self, input_size, hidden_size):
131131
132132 @jit .script_method
133133 def forward (
134- self , input : Tensor , state : Tuple [Tensor , Tensor ]
135- ) -> Tuple [Tensor , Tuple [Tensor , Tensor ]]:
134+ self , input : Tensor , state : tuple [Tensor , Tensor ]
135+ ) -> tuple [Tensor , tuple [Tensor , Tensor ]]:
136136 hx , cx = state
137137 gates = (
138138 torch .mm (input , self .weight_ih .t ())
@@ -199,8 +199,8 @@ def __init__(self, input_size, hidden_size, decompose_layernorm=False):
199199
200200 @jit .script_method
201201 def forward (
202- self , input : Tensor , state : Tuple [Tensor , Tensor ]
203- ) -> Tuple [Tensor , Tuple [Tensor , Tensor ]]:
202+ self , input : Tensor , state : tuple [Tensor , Tensor ]
203+ ) -> tuple [Tensor , tuple [Tensor , Tensor ]]:
204204 hx , cx = state
205205 igates = self .layernorm_i (torch .mm (input , self .weight_ih .t ()))
206206 hgates = self .layernorm_h (torch .mm (hx , self .weight_hh .t ()))
@@ -225,8 +225,8 @@ def __init__(self, cell, *cell_args):
225225
226226 @jit .script_method
227227 def forward (
228- self , input : Tensor , state : Tuple [Tensor , Tensor ]
229- ) -> Tuple [Tensor , Tuple [Tensor , Tensor ]]:
228+ self , input : Tensor , state : tuple [Tensor , Tensor ]
229+ ) -> tuple [Tensor , tuple [Tensor , Tensor ]]:
230230 inputs = input .unbind (0 )
231231 outputs = torch .jit .annotate (List [Tensor ], [])
232232 for i in range (len (inputs )):
@@ -242,8 +242,8 @@ def __init__(self, cell, *cell_args):
242242
243243 @jit .script_method
244244 def forward (
245- self , input : Tensor , state : Tuple [Tensor , Tensor ]
246- ) -> Tuple [Tensor , Tuple [Tensor , Tensor ]]:
245+ self , input : Tensor , state : tuple [Tensor , Tensor ]
246+ ) -> tuple [Tensor , tuple [Tensor , Tensor ]]:
247247 inputs = reverse (input .unbind (0 ))
248248 outputs = jit .annotate (List [Tensor ], [])
249249 for i in range (len (inputs )):
@@ -266,11 +266,11 @@ def __init__(self, cell, *cell_args):
266266
267267 @jit .script_method
268268 def forward (
269- self , input : Tensor , states : List [Tuple [Tensor , Tensor ]]
270- ) -> Tuple [Tensor , List [Tuple [Tensor , Tensor ]]]:
269+ self , input : Tensor , states : List [tuple [Tensor , Tensor ]]
270+ ) -> tuple [Tensor , List [tuple [Tensor , Tensor ]]]:
271271 # List[LSTMState]: [forward LSTMState, backward LSTMState]
272272 outputs = jit .annotate (List [Tensor ], [])
273- output_states = jit .annotate (List [Tuple [Tensor , Tensor ]], [])
273+ output_states = jit .annotate (List [tuple [Tensor , Tensor ]], [])
274274 # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
275275 i = 0
276276 for direction in self .directions :
@@ -300,10 +300,10 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
300300
301301 @jit .script_method
302302 def forward (
303- self , input : Tensor , states : List [Tuple [Tensor , Tensor ]]
304- ) -> Tuple [Tensor , List [Tuple [Tensor , Tensor ]]]:
303+ self , input : Tensor , states : List [tuple [Tensor , Tensor ]]
304+ ) -> tuple [Tensor , List [tuple [Tensor , Tensor ]]]:
305305 # List[LSTMState]: One state per layer
306- output_states = jit .annotate (List [Tuple [Tensor , Tensor ]], [])
306+ output_states = jit .annotate (List [tuple [Tensor , Tensor ]], [])
307307 output = input
308308 # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
309309 i = 0
@@ -330,11 +330,11 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
330330
331331 @jit .script_method
332332 def forward (
333- self , input : Tensor , states : List [List [Tuple [Tensor , Tensor ]]]
334- ) -> Tuple [Tensor , List [List [Tuple [Tensor , Tensor ]]]]:
333+ self , input : Tensor , states : List [List [tuple [Tensor , Tensor ]]]
334+ ) -> tuple [Tensor , List [List [tuple [Tensor , Tensor ]]]]:
335335 # List[List[LSTMState]]: The outer list is for layers,
336336 # inner list is for directions.
337- output_states = jit .annotate (List [List [Tuple [Tensor , Tensor ]]], [])
337+ output_states = jit .annotate (List [List [tuple [Tensor , Tensor ]]], [])
338338 output = input
339339 # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
340340 i = 0
@@ -370,10 +370,10 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
370370
371371 @jit .script_method
372372 def forward (
373- self , input : Tensor , states : List [Tuple [Tensor , Tensor ]]
374- ) -> Tuple [Tensor , List [Tuple [Tensor , Tensor ]]]:
373+ self , input : Tensor , states : List [tuple [Tensor , Tensor ]]
374+ ) -> tuple [Tensor , List [tuple [Tensor , Tensor ]]]:
375375 # List[LSTMState]: One state per layer
376- output_states = jit .annotate (List [Tuple [Tensor , Tensor ]], [])
376+ output_states = jit .annotate (List [tuple [Tensor , Tensor ]], [])
377377 output = input
378378 # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
379379 i = 0
0 commit comments