Skip to content

Commit 3305385

Browse files
committed
fix torch init constant
1 parent de7d9dc commit 3305385

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tensorlayerx/nn/initializers/torch_initializers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
import tensorlayerx as tlx
6+
import numpy as np
67

78
__all__ = [
89
'Initializer',
@@ -123,7 +124,11 @@ def __init__(self, value=0):
123124

124125
def __call__(self, shape, dtype=tlx.float32):
125126
_tensor = torch.empty(size=shape, dtype=dtype)
126-
return torch.nn.init.constant_(_tensor, val=self.value)
127+
if isinstance(self.value, (int, float)):
128+
return torch.nn.init.constant_(_tensor, val=self.value)
129+
elif isinstance(self.value, (torch.Tensor, list, np.ndarray)):
130+
_tensor.data = torch.as_tensor(self.value)
131+
return _tensor
127132

128133
def get_config(self):
129134
return {"value": self.value}

0 commit comments

Comments
 (0)