
Commit 69e3a29

fix ModuleList ModuleDict and paddle dtype
1 parent 2625104 commit 69e3a29

5 files changed: 60 additions & 37 deletions

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 35 additions & 19 deletions
@@ -2,33 +2,49 @@
 # -*- coding: utf-8 -*-

 from __future__ import absolute_import, division, print_function
+
+import paddle
 import paddle as pd
 import paddle.nn as nn
 import numpy as np
 import paddle.nn.functional as F
 from .paddle_nn import nchw_to_nhwc, nhwc_to_nchw, preprocess_2d_format, preprocess_1d_format, preprocess_3d_format
 import random

-_dtypeDict = [
-    "float16", "float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "bool",
-    "complex64", "complex128"
-]
+
+_dtypeDict = {
+    'DType': paddle.dtype,
+    'float16': paddle.float16,
+    'float32': paddle.float32,
+    'float64': paddle.float64,
+    'int8': paddle.int8,
+    'int16': paddle.int16,
+    'int32': paddle.int32,
+    'int64': paddle.int64,
+    'uint8': paddle.uint8,
+    'uint16': None,
+    'uint32': None,
+    'uint64': None,
+    'bool': paddle.bool,
+    'complex64': paddle.complex64,
+    'complex128': paddle.complex128
+}
 # TODO NotImplemented
-DType = None
-float16 = "float16"
-float32 = "float32"
-float64 = "float64"
-int8 = "int8"
-int16 = "int16"
-int32 = "int32"
-int64 = "int64"
-uint8 = "uint8"
-uint16 = "uint16"
-uint32 = "uint32"
-uint64 = "uint64"
-bool = "bool"
-complex64 = "complex64"
-complex128 = "complex128"
+DType = paddle.dtype
+float16 = paddle.float16
+float32 = paddle.float32
+float64 = paddle.float64
+int8 = paddle.int8
+int16 = paddle.int16
+int32 = paddle.int32
+int64 = paddle.int64
+uint8 = paddle.uint8
+uint16 = None
+uint32 = None
+uint64 = None
+bool = paddle.bool
+complex64 = paddle.complex64
+complex128 = paddle.complex128


 def _getter(init_fn, **kwargs):
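
Note: with the constants now bound to real paddle dtype objects (uint16/uint32/uint64 stay None, since no Paddle type is mapped for them), they can be passed straight to Paddle APIs instead of being plain strings. A minimal sketch of the intended effect, assuming the module is importable as tensorlayerx.backend.ops.paddle_backend (import path assumed; tensor values illustrative):

    import paddle
    from tensorlayerx.backend.ops import paddle_backend as ops  # assumed import path

    x = paddle.to_tensor([1.0, 2.0, 3.0], dtype=ops.float32)  # ops.float32 is paddle.float32 per this commit
    y = paddle.cast(x, ops.int64)                             # dtype object accepted directly
    print(x.dtype, y.dtype)                                   # both are paddle dtypes, no string translation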

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 7 additions & 5 deletions
@@ -477,9 +477,10 @@ class ModuleList(Module):
     >>> layer_list.append(d3)
     """

-    def __init__(self, args):
-        Module.__init__(self)
-        self.extend(args)
+    def __init__(self, modules = None):
+        super(ModuleList, self).__init__()
+        if modules is not None:
+            self.extend(modules)

     def __getitem__(self, index):
         if isinstance(index, slice):
@@ -554,9 +555,10 @@ def forward(self, *inputs):

 class ModuleDict(Module):

-    def __init__(self, modules):
+    def __init__(self, modules = None):
         super(ModuleDict, self).__init__()
-        self.update(modules)
+        if modules is not None:
+            self.update(modules)

     def __getitem__(self, key):
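
Note: across all backends the effect is the same — ModuleList() can now be constructed empty and populated afterwards, matching the append-based usage in its docstring. A minimal sketch, assuming a configured backend and that tensorlayerx.nn.Linear takes out_features/in_features keyword arguments (layer sizes illustrative):

    from tensorlayerx.nn import Linear, ModuleList  # Linear signature assumed

    layers = ModuleList()                                   # no longer requires an argument
    layers.append(Linear(out_features=20, in_features=10))
    layers.extend([Linear(out_features=5, in_features=20)])

    preset = ModuleList([Linear(out_features=3, in_features=5)])  # old call style still works
    first = layers[0]                                       # __getitem__ as shown in the diff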

tensorlayerx/nn/core/core_paddle.py

Lines changed: 6 additions & 5 deletions
@@ -416,10 +416,10 @@ def forward(self, input_data):

 class ModuleList(Module):

-    def __init__(self, args):
+    def __init__(self, modules = None):
         super(ModuleList, self).__init__()
-        if args is not None:
-            self.extend(args)
+        if modules is not None:
+            self.extend(modules)

     def __getitem__(self, index):
         if isinstance(index, slice):
@@ -501,9 +501,10 @@ def forward(self, *inputs):

 class ModuleDict(Module):

-    def __init__(self, modules):
+    def __init__(self, modules = None):
         super(ModuleDict, self).__init__()
-        self.update(modules)
+        if modules is not None:
+            self.update(modules)

     def __getitem__(self, key):
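
Note: ModuleDict gets the same treatment, so a named collection of sub-layers can be assembled incrementally via update(). Another small sketch under the same assumptions about Linear:

    from tensorlayerx.nn import Linear, ModuleDict  # Linear signature assumed, as above

    heads = ModuleDict()                                    # empty construction now allowed
    heads.update({'classifier': Linear(out_features=10, in_features=64)})
    heads.update({'regressor': Linear(out_features=1, in_features=64)})
    clf = heads['classifier']                               # __getitem__ as shown in the diff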

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 6 additions & 4 deletions
@@ -759,9 +759,10 @@ class ModuleList(Module):
     >>> layer_list.append(d3)
     """

-    def __init__(self, args):
+    def __init__(self, modules = None):
         super(ModuleList, self).__init__()
-        self.extend(args)
+        if modules is not None:
+            self.extend(modules)

     def __getitem__(self, index):
         if isinstance(index, slice):
@@ -890,9 +891,10 @@ class ModuleDict(Module):

     """

-    def __init__(self, modules):
+    def __init__(self, modules = None):
         super(ModuleDict, self).__init__()
-        self.update(modules)
+        if modules is not None:
+            self.update(modules)

     def __getitem__(self, key):
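
Note: the typical payoff is inside a user-defined Module, where the container can be declared first and filled in a loop rather than pre-built as a list. A sketch of that pattern (the Module/forward subclassing style and the Linear arguments are assumptions based on the surrounding API, not part of this commit):

    from tensorlayerx.nn import Module, ModuleList, Linear

    class MLP(Module):
        """Illustrative only: declares an empty ModuleList, then appends to it."""

        def __init__(self):
            super(MLP, self).__init__()
            self.blocks = ModuleList()                      # relies on the new default argument
            for in_f, out_f in [(16, 32), (32, 8)]:
                self.blocks.append(Linear(out_features=out_f, in_features=in_f))

        def forward(self, x):
            for i in range(2):                              # indexing uses __getitem__ from the diff
                x = self.blocks[i](x)
            return x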

tensorlayerx/nn/core/core_torch.py

Lines changed: 6 additions & 4 deletions
@@ -375,9 +375,10 @@ class ModuleList(Module):
     >>> layer_list.append(d3)
     """

-    def __init__(self, args):
+    def __init__(self, modules = None):
         super(ModuleList, self).__init__()
-        self.extend(args)
+        if modules is not None:
+            self.extend(modules)

     def __getitem__(self, index):
         if isinstance(index, slice):
@@ -459,9 +460,10 @@ def forward(self, *inputs):

 class ModuleDict(Module):

-    def __init__(self, modules):
+    def __init__(self, modules = None):
         super(ModuleDict, self).__init__()
-        self.update(modules)
+        if modules is not None:
+            self.update(modules)

     def __getitem__(self, key):
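
Note: for comparison, this brings the TensorLayerX containers in line with torch.nn, whose ModuleList and ModuleDict already default to modules=None:

    import torch.nn as nn

    torch_list = nn.ModuleList()                 # torch allows empty construction
    torch_list.append(nn.Linear(10, 20))
    torch_dict = nn.ModuleDict()
    torch_dict.update({'fc': nn.Linear(20, 5)})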
