@@ -83,7 +83,7 @@ def get_tensor_shape(x):
8383
8484
8585# initializers
86- def zeros (shape , dtype = 'float32' ):
86+ def zeros (shape , dtype = 'float32' , device = None ):
8787 """
8888 Creates a tensor with all elements set to zero.
8989
@@ -93,6 +93,8 @@ def zeros(shape, dtype='float32'):
9393 a tuple of integers, or a 1-D Tensor of type int32.
9494 dtype : tensor or str
9595 The DType of an element in the resulting Tensor
96+ device : str or None
97+ create a tensor on 'cpu' or 'gpu', defautl is None.
9698
9799 Returns
98100 -------
@@ -109,7 +111,7 @@ def zeros(shape, dtype='float32'):
109111 return tf .zeros (shape = shape , dtype = dtype_str (dtype ))
110112
111113
112- def ones (shape , dtype = 'float32' ):
114+ def ones (shape , dtype = 'float32' , device = None ):
113115 """
114116 Creates a tensor with all elements set to ones.
115117
@@ -119,6 +121,8 @@ def ones(shape, dtype='float32'):
119121 a tuple of integers, or a 1-D Tensor of type int32.
120122 dtype : tensor or str
121123 The DType of an element in the resulting Tensor
124+ device : str or None
125+ create a tensor on 'cpu' or 'gpu', defautl is None.
122126
123127 Returns
124128 -------
@@ -135,7 +139,7 @@ def ones(shape, dtype='float32'):
135139 return tf .ones (shape = shape , dtype = dtype_str (dtype ))
136140
137141
138- def constant (value , dtype = 'float32' , shape = None ):
142+ def constant (value , dtype = 'float32' , shape = None , device = None ):
139143 """
140144 Creates a constant tensor from a tensor-like object.
141145
@@ -147,6 +151,8 @@ def constant(value, dtype='float32', shape=None):
147151 The type of the elements of the resulting tensor.
148152 shape : tuple
149153 Optional dimensions of resulting tensor.
154+ device : str or None
155+ create a tensor on 'cpu' or 'gpu', defautl is None.
150156
151157 Returns
152158 -------
@@ -345,7 +351,7 @@ def xavier_uniform(shape, dtype='float32', seed=None):
345351 return tf .initializers .glorot_uniform (seed )(shape = shape , dtype = dtype_str (dtype ))
346352
347353
348- def Variable (initial_value , name , trainable = True ):
354+ def Variable (initial_value , name , trainable = True , device = None ):
349355 """
350356 Creates a new variable with value initial_value.
351357
@@ -355,6 +361,8 @@ def Variable(initial_value, name, trainable=True):
355361 A Tensor, or Python object convertible to a Tensor
356362 name : str
357363 Optional name for the variable. Defaults to 'Variable' and gets uniquified automatically.
364+ device : str or None
365+ create a tensor on 'cpu' or 'gpu', defautl is None.
358366 Returns
359367 -------
360368 Variable
@@ -591,7 +599,7 @@ def concat(values, axis):
591599 return tf .concat (values , axis )
592600
593601
594- def convert_to_tensor (value , dtype = None ):
602+ def convert_to_tensor (value , dtype = None , device = None ):
595603 """
596604 Converts the given value to a Tensor.
597605
@@ -601,6 +609,8 @@ def convert_to_tensor(value, dtype=None):
601609 An object whose type has a registered Tensor conversion function.
602610 dtype : optional
603611 Optional element type for the returned tensor. If missing, the type is inferred from the type of value.
612+ device : str or None
613+ create a tensor on 'cpu' or 'gpu', defautl is None.
604614
605615 Returns
606616 -------
0 commit comments