@@ -771,7 +771,7 @@ def __init__(self):
771771 pass
772772
773773 def __call__ (self , input , multiples ):
774- raise NotImplementedError
774+ return torch . tile ( input , dims = multiples )
775775
776776
def tile(input, multiples):
    """
    Constructs a tensor by tiling a given tensor.

    Parameters
    ----------
    input : tensor
        A Tensor.
    multiples : tuple or list of int
        Number of repetitions along each dimension of `input`.

    Returns
    -------
    A Tensor. Has the same type as input.
    """

    repeated = torch.tile(input, multiples)
    return repeated
795795
796796
class Cast(object):
    """Callable that casts its input tensor to a fixed dtype."""

    def __init__(self, dtype=None):
        # Destination dtype applied on every call.
        self.dtype = dtype

    def __call__(self, x):
        """Cast `x` to `self.dtype` via Tensor.type."""
        return x.type(self.dtype)
804804
805805
def cast(x, dtype=None):
    """
    Casts a tensor to a new type.

    Parameters
    ----------
    x : tensor
        A Tensor.
    dtype : dtype
        The destination type.

    Returns
    -------
    A Tensor with the same shape as x and the requested dtype.
    """

    converted = x.type(dtype)
    return converted
824824
825825
826826class Transpose (object ):
@@ -830,7 +830,7 @@ def __init__(self, perm, conjugate=False):
830830 self .conjugate = conjugate
831831
832832 def __call__ (self , a ):
833- raise NotImplementedError
833+ return transpose ( a , self . perm , self . conjugate )
834834
835835
def transpose(a, perm=None, conjugate=False):
    """
    Transposes `a`, permuting its dimensions according to `perm`.

    Parameters
    ----------
    a : tensor
        A Tensor.
    perm : sequence of int or None
        A permutation of the dimensions of `a`. Defaults to reversing all
        axes (the tf.transpose / np.transpose convention), for any rank.
    conjugate : bool
        If True, also take the element-wise complex conjugate of the result.

    Returns
    -------
    A transposed Tensor.
    """

    if perm is None:
        # Reverse all axes for any rank. The previous code only handled
        # ranks up to 5 and skipped `conjugate` entirely for rank <= 2.
        perm = tuple(range(a.dim() - 1, -1, -1))
    out = torch.permute(a, perm)
    if conjugate:
        out = torch.conj_physical(out)
    return out
855866
856867
def gather_nd(params, indices, batch_dims=0):
    """
    Gather slices from `params` into a Tensor with shape given by `indices`.

    Parameters
    ----------
    params : tensor
        The tensor from which to gather values.
    indices : tensor of int
        Index tensor; the size of its last dimension is the number of
        leading dimensions of `params` being indexed.
    batch_dims : int
        Accepted for API compatibility. NOTE(review): batched gathering is
        not implemented — only batch_dims == 0 behaves correctly.

    Returns
    -------
    A Tensor. Has the same type as params.
    """

    index_depth = indices.shape[-1]
    outer_shape = tuple(indices.shape[:-1])
    flat_idx = indices.long().reshape(-1, index_depth)

    # Row-major strides over the indexed leading dimensions of params.
    strides = []
    size = 1
    for dim in reversed(params.shape[:index_depth]):
        strides.append(size)
        size *= dim
    strides.reverse()

    # Linearize each multi-index into the collapsed leading dimensions.
    linear = flat_idx[:, 0] * strides[0]
    for d in range(1, index_depth):
        linear = linear + flat_idx[:, d] * strides[d]

    # Keep any trailing (un-indexed) dimensions as whole slices; the old
    # torch.take-based code flattened ALL of params and so was only correct
    # when index_depth == params.dim().
    tail_shape = tuple(params.shape[index_depth:])
    rows = params.reshape((size,) + tail_shape)[linear]
    return rows.reshape(outer_shape + tail_shape)
876898
877899
def clip_by_value(t, clip_value_min, clip_value_max):
    """
    Clips tensor values to a specified min and max.

    Parameters
    ----------
    t : tensor
        A Tensor.
    clip_value_min : scalar or tensor
        Lower bound, broadcastable to the shape of `t`.
    clip_value_max : scalar or tensor
        Upper bound, broadcastable to the shape of `t`.

    Returns
    -------
    A clipped Tensor.
    """

    # torch.clamp replaces the hand-rolled boolean-mask arithmetic: one
    # fused native op instead of four intermediate tensors.
    return torch.clamp(t, clip_value_min, clip_value_max)
898924
def split(value, num_or_size_splits, axis=0):
    """
    Splits a tensor into sub tensors (tf.split semantics).

    Parameters
    ----------
    value : tensor
        The Tensor to split.
    num_or_size_splits : int or list of int
        An int N means split into N equal-sized pieces along `axis`;
        a list gives the size of each output piece along `axis`.
    axis : int
        The dimension along which to split.

    Returns
    -------
    Tensor objects resulting from splitting value.
    """

    if isinstance(num_or_size_splits, int):
        # tf.split treats an int as the NUMBER of equal chunks, whereas
        # torch.split would treat it as the SIZE of each chunk.
        return torch.chunk(value, num_or_size_splits, dim=axis)
    return torch.split(value, num_or_size_splits, dim=axis)
921948
922949
class Floor(object):
    """Callable applying an element-wise floor to its input tensor."""

    def __call__(self, x):
        """Return `x` rounded down element-wise."""
        return torch.floor(x)
927954
928955
def floor(x):
    """Return the element-wise floor of tensor `x` (largest integers <= x)."""
    return torch.floor(x)
931958
932959
def gather(params, indices):
    """
    Gather slices from `params` along axis 0 according to `indices`
    (tf.gather semantics).

    Parameters
    ----------
    params : tensor
        The tensor from which to gather rows/slices.
    indices : tensor of int
        Indices into the first dimension of `params`.

    Returns
    -------
    A Tensor with shape indices.shape + params.shape[1:].
    """

    # The previous implementation delegated to gather_nd, which interprets
    # `indices` as multi-dimensional coordinates rather than axis-0 indices,
    # so a plain 1-D index list returned a single (wrong) element.
    return params[indices.long()]
935962
936963
def linspace(start, stop, num):
    """
    Return `num` evenly spaced values over the closed interval [start, stop].

    Parameters
    ----------
    start : number
        First value of the sequence.
    stop : number
        Last value of the sequence (inclusive).
    num : int
        Number of points to generate.

    Returns
    -------
    A 1-D Tensor of `num` values.
    """
    return torch.linspace(start, stop, num)
939966
940967
def slice(inputs, starts, sizes):
    """
    Extracts a slice from a tensor (tf.slice semantics).

    Parameters
    ----------
    inputs : tensor
        The tensor to slice.
    starts : list or tuple of int
        Start index for each dimension.
    sizes : list or tuple of int
        Number of elements to take in each dimension; -1 takes all
        remaining elements of that dimension.

    Returns
    -------
    A Tensor with the same type as `inputs`.
    """

    # Implements the former NotImplementedError stub with torch.narrow,
    # one dimension at a time (narrow returns a view, so this is cheap).
    out = inputs
    for dim, (begin, length) in enumerate(zip(starts, sizes)):
        if length == -1:
            # tf.slice: -1 means "everything from `begin` to the end".
            length = out.shape[dim] - begin
        out = out.narrow(dim, begin, length)
    return out
943970
944971
def add_n(inputs):
    """
    Adds all input tensors element-wise.

    Parameters
    ----------
    inputs : list of tensors
        Tensors to sum; must all have the same shape.

    Returns
    -------
    A Tensor of the same shape and type as the elements of `inputs`.
    """

    # Use out-of-place addition: the previous `a += b` loop mutated
    # inputs[0] in place, corrupting the caller's first tensor.
    total = inputs[0]
    for tensor in inputs[1:]:
        total = total + tensor
    return total
977+
947978
948979
class OneHot(object):
    """Callable producing a one-hot tensor (tf.one_hot compatible)."""

    def __init__(self, depth=-1, on_value=None, off_value=None, axis=None, dtype=None):
        self.depth = depth
        self.on_value = on_value
        self.off_value = off_value
        # NOTE(review): `axis` is accepted for API compatibility but is
        # currently ignored — only the default (last axis) is produced.
        self.axis = axis
        self.dtype = dtype

    def __call__(self, inputs):
        """Return the one-hot encoding of integer tensor `inputs`."""
        out = torch.nn.functional.one_hot(inputs, self.depth)
        if self.on_value is None and self.off_value is None:
            return out
        on = 1.0 if self.on_value is None else self.on_value
        off = 0.0 if self.off_value is None else self.off_value
        # Compute in float64 so fractional on/off values survive; the old
        # code cast the result to torch.int, truncating e.g. 0.9 to 0.
        out = out.to(torch.float64)
        on_t = torch.tensor(on, dtype=torch.float64)
        off_t = torch.tensor(off, dtype=torch.float64)
        out = torch.where(out == 1, on_t, out)
        out = torch.where(out == 0, off_t, out)
        if self.dtype is not None:
            # Honor the requested output dtype (previously ignored).
            out = out.to(self.dtype)
        return out
960999
9611000
9621001class L2Normalize (object ):
0 commit comments