@@ -12,7 +12,7 @@
 import contextlib
 import copy
 import re
-from typing import Callable, Dict, Optional, Tuple, Type, Union
+from typing import Callable, Dict, Optional, Type, Union
 
 import torch
 
@@ -118,7 +118,7 @@ def default_convert(data):
 def collate(
     batch,
     *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
 ):
     r"""
     General collate function that handles collection type of element within each batch.
@@ -243,7 +243,7 @@ def collate(
 def collate_tensor_fn(
     batch,
     *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
 ):
     elem = batch[0]
     out = None
@@ -275,7 +275,7 @@ def collate_tensor_fn(
 def collate_numpy_array_fn(
     batch,
     *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
 ):
     elem = batch[0]
     # array of string classes and object
@@ -288,36 +288,36 @@ def collate_numpy_array_fn(
 def collate_numpy_scalar_fn(
     batch,
     *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
 ):
     return torch.as_tensor(batch)
 
 
 def collate_float_fn(
     batch,
     *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
 ):
     return torch.tensor(batch, dtype=torch.float64)
 
 
 def collate_int_fn(
     batch,
     *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
 ):
     return torch.tensor(batch)
 
 
 def collate_str_fn(
     batch,
     *,
-    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
+    collate_fn_map: Optional[Dict[Union[Type, tuple[Type, ...]], Callable]] = None,
 ):
     return batch
 
 
-default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {
+default_collate_fn_map: Dict[Union[Type, tuple[Type, ...]], Callable] = {
     torch.Tensor: collate_tensor_fn
 }
 with contextlib.suppress(ImportError):
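
The diff appears to swap the deprecated typing.Tuple alias for the builtin tuple generic (PEP 585, available since Python 3.9), which is why the Tuple import can be dropped. A minimal sketch, not part of the commit, showing the two spellings behave the same at runtime:

    from typing import Callable, Dict, Tuple, Type, Union, get_origin

    # The annotation used throughout the diff, spelled both ways.
    legacy = Dict[Union[Type, Tuple[Type, ...]], Callable]
    modern = Dict[Union[Type, tuple[Type, ...]], Callable]

    # Both resolve to the same runtime origin; only the builtin form
    # avoids importing the deprecated typing.Tuple alias.
    assert get_origin(legacy) is get_origin(modern) is dict
    assert get_origin(Tuple[Type, ...]) is get_origin(tuple[Type, ...]) is tuple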