
Commit 7fc6650

Update torch backend
1 parent 9b7b27b commit 7fc6650

File tree: 4 files changed, +64 −25 lines


tensorlayerx/backend/ops/mindspore_backend.py

Lines changed: 1 addition & 1 deletion
@@ -1090,7 +1090,7 @@ def clip_by_value(t, clip_value_min, clip_value_max):
    return output


-def split(value, num_or_size_splits, axis=0, num=None):
+def split(value, num_or_size_splits, axis=0):
    """
    Splits a tensor into sub tensors.

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 1 addition & 1 deletion
@@ -869,7 +869,7 @@ def clip_by_value(t, clip_value_min, clip_value_max):
    return pd.clip(t, clip_value_min, clip_value_max)


-def split(value, num_or_size_splits, axis=0, num=None):
+def split(value, num_or_size_splits, axis=0):
    """
    Splits a tensor into sub tensors.
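
The unused `num` keyword is dropped from the MindSpore and Paddle wrappers, so callers pass at most a value, the split sizes, and an axis. A minimal usage sketch (not part of this commit), assuming Paddle is installed and the wrapper forwards to paddle.split as before:

import paddle as pd
from tensorlayerx.backend.ops.paddle_backend import split

x = pd.arange(12, dtype='float32').reshape([3, 4])
halves = split(x, 2, axis=1)       # two tensors of shape [3, 2]
uneven = split(x, [1, 3], axis=1)  # tensors of shape [3, 1] and [3, 3]
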

tensorlayerx/backend/ops/tensorflow_backend.py

Lines changed: 3 additions & 3 deletions
@@ -1133,7 +1133,7 @@ def split(value, num_or_size_splits, axis=0):
    ----------
    value : tensor
        The Tensor to split.
-    num_or_size_splits : list
+    num_or_size_splits : int or list
        Either an integer indicating the number of splits along split_dim or a 1-D integer Tensor or
        Python list containing the sizes of each output tensor along split_dim.
    axis : int

@@ -1153,7 +1153,7 @@ def split(value, num_or_size_splits, axis=0):

    """

-    return tf.split(value=value, num_or_size_splits=num_or_size_splits, axis=axis, num=num)
+    return tf.split(value=value, num_or_size_splits=num_or_size_splits, axis=axis)


class Floor(object):

@@ -1230,7 +1230,7 @@ def add_n(inputs):

class OneHot(object):

-    def __init__(self, depth, on_value, off_value, axis, dtype):
+    def __init__(self, depth, on_value=None, off_value=None, axis=None, dtype=None):
        self.depth = depth
        self.on_value = on_value
        self.off_value = off_value
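
As a quick check of the TensorFlow-side change, a hypothetical sketch (not part of this commit), assuming TensorFlow 2.x: the wrapper now documents both forms of num_or_size_splits and no longer forwards a num argument, and OneHot can be constructed from a depth alone:

import tensorflow as tf
from tensorlayerx.backend.ops.tensorflow_backend import split, OneHot

x = tf.reshape(tf.range(12), [3, 4])
even_parts = split(x, 2, axis=1)        # int form: two tensors of shape [3, 2]
sized_parts = split(x, [1, 3], axis=1)  # list form: shapes [3, 1] and [3, 3]

onehot = OneHot(depth=4)                # on_value/off_value/axis/dtype now default to None
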

tensorlayerx/backend/ops/torch_backend.py

Lines changed: 59 additions & 20 deletions
@@ -771,7 +771,7 @@ def __init__(self):
        pass

    def __call__(self, input, multiples):
-        raise NotImplementedError
+        return torch.tile(input, dims=multiples)


def tile(input, multiples):
@@ -791,19 +791,19 @@ def tile(input, multiples):
        A Tensor. Has the same type as input.
    """

-    raise NotImplementedError
+    return torch.tile(input, multiples)


class Cast(object):

-    def __init__(self, dtype):
+    def __init__(self, dtype=None):
        self.dtype = dtype

    def __call__(self, x):
-        raise NotImplementedError
+        return x.type(self.dtype)


-def cast(x, dtype):
+def cast(x, dtype=None):
    """
    Casts a tensor to a new type.
@@ -820,7 +820,7 @@ def cast(x, dtype):
        A Tensor or SparseTensor or IndexedSlices with same shape as x and same type as dtype.
    """

-    raise NotImplementedError
+    return x.type(dtype)


class Transpose(object):
@@ -830,7 +830,7 @@ def __init__(self, perm, conjugate=False):
        self.conjugate = conjugate

    def __call__(self, a):
-        raise NotImplementedError
+        return transpose(a, self.perm, self.conjugate)


def transpose(a, perm=None, conjugate=False):
@@ -850,8 +850,19 @@ def transpose(a, perm=None, conjugate=False):
    -------
        A transposed Tensor.
    """
-
-    raise NotImplementedError
+    if perm == None:
+        if len(a.shape) <= 2:
+            return torch.t(a)
+        if len(a.shape) == 3:
+            perm = [2, 1, 0]
+        if len(a.shape) == 4:
+            perm = [3, 2, 1, 0]
+        if len(a.shape) == 5:
+            perm = [4, 3, 2, 1, 0]
+    out = torch.permute(a, perm)
+    if conjugate:
+        out = torch.conj_physical(out)
+    return out


def gather_nd(params, indices, batch_dims=0):
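
For illustration only (not part of the commit): with perm=None the new fallback uses torch.t for 1-D/2-D inputs and reverses the axis order for 3-D to 5-D inputs, assuming a torch version that provides torch.permute and torch.conj_physical (1.9 or later):

import torch
from tensorlayerx.backend.ops.torch_backend import transpose

a = torch.zeros(2, 3)
print(transpose(a).shape)                  # torch.Size([3, 2]) via torch.t

b = torch.zeros(2, 3, 4)
print(transpose(b).shape)                  # torch.Size([4, 3, 2]), default perm [2, 1, 0]
print(transpose(b, perm=[0, 2, 1]).shape)  # torch.Size([2, 4, 3]) with an explicit perm
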
@@ -872,7 +883,18 @@ def gather_nd(params, indices, batch_dims=0):
        A Tensor. Has the same type as params.
    """

-    raise NotImplementedError
+    out_shape = indices.shape[:-1]
+    indices = indices.unsqueeze(0).transpose(0, -1)
+    ndim = indices.shape[0]
+    indices = indices.long()
+    idx = torch.zeros_like(indices[0], device=indices.device).long()
+    m = 1
+
+    for i in range(ndim)[::-1]:
+        idx += indices[i] * m
+        m *= params.size(i)
+    out = torch.take(params, idx)
+    return out.view(out_shape)


def clip_by_value(t, clip_value_min, clip_value_max):
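
To make the index arithmetic concrete, a small sketch (not part of the commit): each trailing index vector is folded into a flat offset using row-major strides of params, and torch.take reads from the flattened tensor:

import torch
from tensorlayerx.backend.ops.torch_backend import gather_nd

params = torch.arange(6).reshape(2, 3)    # [[0, 1, 2], [3, 4, 5]]
indices = torch.tensor([[0, 1], [1, 2]])  # select params[0, 1] and params[1, 2]
print(gather_nd(params, indices))         # tensor([1, 5])
# Flat offsets: 0 * 3 + 1 = 1 and 1 * 3 + 2 = 5, i.e. row index times row stride plus column index.
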
@@ -893,10 +915,15 @@
        A clipped Tensor or IndexedSlices.
    """

-    raise NotImplementedError
+    t_min = clip_value_min
+    t_max = clip_value_max

+    result = (t >= t_min) * t + (t < t_min) * t_min
+    result = (result <= t_max) * result + (result > t_max) * t_max
+    return result

-def split(value, num_or_size_splits, axis=0, num=None):
+
+def split(value, num_or_size_splits, axis=0):
    """
    Splits a tensor into sub tensors.
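
The comparison-mask arithmetic above is equivalent to clamping; a tiny sketch (not part of the commit), assuming floating-point inputs so the boolean masks promote cleanly:

import torch
from tensorlayerx.backend.ops.torch_backend import clip_by_value

t = torch.tensor([-2.0, 0.5, 3.0])
print(clip_by_value(t, 0.0, 1.0))  # tensor([0.0000, 0.5000, 1.0000])
print(torch.clamp(t, 0.0, 1.0))    # the built-in clamp gives the same result
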
@@ -917,46 +944,58 @@ def split(value, num_or_size_splits, axis=0):
        Tensor objects resulting from splitting value.
    """

-    raise NotImplementedError
+    return torch.split(value, num_or_size_splits, dim=axis)


class Floor(object):

    def __call__(self, x):
-        raise NotImplementedError
+        return torch.floor(x)


def floor(x):
-    raise NotImplementedError
+    return torch.floor(x)


def gather(params, indices):
-    raise NotImplementedError
+    return gather_nd(params, indices)


def linspace(start, stop, num):
-    raise NotImplementedError
+    return torch.linspace(start=start, end=stop, steps=num)


def slice(inputs, starts, sizes):
    raise NotImplementedError


def add_n(inputs):
-    raise NotImplementedError
+    a = inputs[0]
+    for b in inputs[1:]:
+        a += b
+    return a
+

class OneHot(object):

-    def __init__(self, depth, on_value, off_value, axis, dtype):
+    def __init__(self, depth=-1, on_value=None, off_value=None, axis=None, dtype=None):
        self.depth = depth
        self.on_value = on_value
        self.off_value = off_value
        self.axis = axis
        self.dtype = dtype

    def __call__(self, inputs):
-        raise NotImplementedError
+        if [self.on_value, self.off_value] == [None, None]:
+            return torch.nn.functional.one_hot(inputs, self.depth)
+        else:
+            out = torch.nn.functional.one_hot(inputs, self.depth)
+            out = cast(out, torch.float64)
+            out = torch.where(out==1, self.on_value, out)
+            out = torch.where(out==0, self.off_value, out)
+            out = cast(out, torch.int)
+            return out


class L2Normalize(object):
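
Finally, a hypothetical sketch (not part of the commit) exercising the default OneHot path and the new add_n; note that, as written in the diff, add_n accumulates in place into the first input tensor:

import torch
from tensorlayerx.backend.ops.torch_backend import OneHot, add_n

labels = torch.tensor([0, 2, 1])
onehot = OneHot(depth=3)   # on_value/off_value left as None, so F.one_hot is returned directly
print(onehot(labels))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])

xs = [torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2)]
print(add_n(xs))           # element-wise sum: a 2x2 tensor filled with 3.
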
