@@ -3890,4 +3890,75 @@ def set_device(device = 'GPU', id = 0):
38903890 tf .config .experimental .set_memory_growth (gpu , True )
38913891 tf .config .experimental .set_visible_devices (gpus [id ], 'GPU' )
38923892 except RuntimeError as e :
3893- print (e )
3893+ print (e )
3894+
3895+ def scatter_update (tensor , indices , updates ):
3896+ """Applies sparse updates to a variable
3897+
3898+ Parameters
3899+ ----------
3900+ tensor : Tensor
3901+ A Tensor. The dim of tensor must be 1.
3902+ indices : Tensor
3903+ Indices into the tensor.
3904+ updates : Tensor
3905+ Updated values
3906+
3907+ Returns
3908+ -------
3909+ Tensor after updated.
3910+
3911+ Examples
3912+ ---------
3913+ >>> import tensorlayerx as tlx
3914+ >>> x = tlx.ops.ones((5,))
3915+ >>> indices = tlx.ops.convert_to_tensor([0, 4, 2])
3916+ >>> updates = tlx.ops.convert_to_tensor([1., 4., 7.])
3917+ >>> res = tlx.ops.scatter_update(x, indices, updates)
3918+ >>> [1. 1. 7. 1. 4.]
3919+ """
3920+ shape = indices .shape
3921+ indices = tf .reshape (indices , shape = (shape [0 ], 1 ))
3922+ return tf .tensor_scatter_nd_update (tensor , indices , updates )
3923+
3924+ def get_device ():
3925+ """This function can get the specified global device.
3926+
3927+ Returns
3928+ -------
3929+ The global device.
3930+
3931+ Examples
3932+ ---------
3933+ >>> import tensorlayerx as tlx
3934+ >>> x = tlx.ops.get_device()
3935+ >>> "CPU"
3936+ """
3937+ device = tf .config .experimental .get_visible_devices ('GPU' )
3938+ if len (device ) == 0 :
3939+ device = tf .config .experimental .get_visible_devices ('CPU' )
3940+ return device
3941+
3942+ def to_device (tensor , device = 'GPU' , id = 0 ):
3943+ """Returns a copy of Tensor in specified device.
3944+
3945+ Parameters
3946+ ----------
3947+ tensor : Tensor
3948+ A tensor.
3949+ device : str
3950+ The specified device. Support 'GPU' and 'CPU'. Default is 'GPU'.
3951+ id : int
3952+ The id of specified device. Default is 0.
3953+
3954+
3955+ Examples
3956+ ---------
3957+ >>> import tensorlayerx as tlx
3958+ >>> x = tlx.ops.ones((5,))
3959+ >>> x = tlx.ops.to_device(x, device="GPU", id=0)
3960+ """
3961+ if device is None :
3962+ return tensor
3963+ with tf .device ("/" + device .upper ()+ ':' + str (id )):
3964+ return tf .identity (tensor )
0 commit comments