2020
2121#include " ps/base.h"
2222#include " ps/hash_table8.hpp"
23- #include " ps/internal/ backend.h"
23+ #include " ps/backend.h"
2424#include " ps/internal/utils.h"
2525#include " ps/kv_app.h"
2626
@@ -236,15 +236,16 @@ class AFTensorWorker {
236236 void ZPush_ (int ts, const SArray<Key>& keys, const at::Tensor& tensor,
237237 int cmd = 0 ) {
238238 SArray<char > val;
239- val.reset (reinterpret_cast <char *>(tensor.data_ptr ()),
239+ void * mappedPtr = Backend::Get ()->GetAccessibleAddr (tensor);
240+ val.reset (reinterpret_cast <char *>(mappedPtr),
240241 tensor.numel () * tensor.itemsize (), [tensor](void *) {});
241242
242243 Message msg;
243244 msg.meta .request = true ;
244245 msg.meta .head = cmd;
245246 msg.meta .push = true ;
246247 msg.meta .timestamp = ts;
247- msg.meta .addr = reinterpret_cast <uint64_t >(tensor. data_ptr () );
248+ msg.meta .addr = reinterpret_cast <uint64_t >(mappedPtr );
248249 msg.meta .val_len = tensor.numel () * tensor.itemsize ();
249250 PS_VLOG (2 ) << " ZPush_ addr: 0x" << std::hex << msg.meta .addr << std::dec
250251 << " val_len: " << msg.meta .val_len ;
@@ -284,13 +285,14 @@ class AFTensorWorker {
284285
285286 *key.data () = pull_tensors[i * pull_batch_size + index].key ;
286287
287- val.reset (reinterpret_cast <char *>(tensor.data_ptr ()),
288+ void * mappedPtr = Backend::Get ()->GetAccessibleAddr (tensor);
289+ val.reset (reinterpret_cast <char *>(mappedPtr),
288290 tensor.numel () * tensor.itemsize (), [tensor](void *) {});
289291
290292 msg.meta .request = true ;
291293 msg.meta .head = cmd;
292294 msg.meta .push = false ;
293- msg.meta .addr = reinterpret_cast <uint64_t >(tensor. data_ptr () );
295+ msg.meta .addr = reinterpret_cast <uint64_t >(mappedPtr );
294296 msg.meta .val_len = tensor.numel () * tensor.itemsize ();
295297 msg.meta .key = key[0 ];
296298 msg.meta .is_tensor = 1 ;
@@ -483,7 +485,8 @@ class AFTensorServer {
483485 res.keys = key;
484486
485487 SArray<char > tensor_val;
486- tensor_val.reset (reinterpret_cast <char *>(tensors[0 ].val .data_ptr ()),
488+ tensor_val.reset (reinterpret_cast <char *>(
489+ Backend::Get ()->GetAccessibleAddr (tensors[0 ].val )),
487490 tensors[0 ].val .numel () * tensors[0 ].val .itemsize (),
488491 [](void *) {});
489492 res.vals = tensor_val;
@@ -506,7 +509,8 @@ class AFTensorServer {
506509 rsp.kv_pair .keys = key;
507510
508511 rsp.kv_pair .vals .reset (
509- reinterpret_cast <char *>(res_kv.val .data_ptr ()),
512+ reinterpret_cast <char *>(
513+ Backend::Get ()->GetAccessibleAddr (res_kv.val )),
510514 res_kv.val .numel () * res_kv.val .itemsize (), [](void *) {});
511515
512516 rsp.kv_meta = kv_meta;
@@ -558,7 +562,8 @@ class AFTensorServer {
558562 PS_CHECK_GT (worker_ranks.size (), 0 ) << " ranks or keys should not be empty" ;
559563 PS_CHECK_EQ (worker_ranks.size (), keys.size ())
560564 << " rank list and key list have unequal size" ;
561- char * buffer_ptr = reinterpret_cast <char *>(tensor.data_ptr ());
565+ char * buffer_ptr =
566+ reinterpret_cast <char *>(Backend::Get ()->GetAccessibleAddr (tensor));
562567 uint64_t data_size = tensor.numel () * tensor.element_size ();
563568 int chunk_size = data_size / worker_ranks.size ();
564569 PS_CHECK_EQ (data_size % worker_ranks.size (), 0 )
@@ -591,8 +596,14 @@ class AFTensorServer {
591596 .dtype (at::ScalarType (req_meta.dtype ))
592597 .memory_format (at::MemoryFormat::Contiguous)
593598 .device (Backend::Get ()->GetDevice ());
594- key_tensor.val =
595- at::from_blob (req_data.vals .data (), req_meta.shape , options);
599+ key_tensor.val = at::from_blob (
600+ Backend::Get ()->GetDeviceAddrFromHostPtr (
601+ req_data.vals .data (),
602+ std::accumulate (std::begin (req_meta.shape ),
603+ std::end (req_meta.shape ),
604+ c10::elementSize (at::ScalarType (req_meta.dtype )),
605+ std::multiplies<uint64_t >())),
606+ req_meta.shape , options);
596607 }
597608 key_tensor.key = req_data.keys [0 ];
598609 return key_tensor;
0 commit comments