1616ur_mem_handle_t_::ur_mem_handle_t_ (ur_context_handle_t hContext, size_t size)
1717 : hContext(hContext), size(size) {}
1818
19- ur_host_mem_handle_t ::ur_host_mem_handle_t (ur_context_handle_t hContext,
20- void *hostPtr, size_t size,
21- host_ptr_action_t hostPtrAction)
19+ ur_usm_handle_t_::ur_usm_handle_t_ (ur_context_handle_t hContext, size_t size,
20+ const void *ptr)
21+ : ur_mem_handle_t_(hContext, size), ptr(const_cast <void *>(ptr)) {}
22+
23+ ur_usm_handle_t_::~ur_usm_handle_t_ () {}
24+
25+ void *ur_usm_handle_t_::getDevicePtr (
26+ ur_device_handle_t hDevice, access_mode_t access, size_t offset,
27+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
28+ std::ignore = hDevice;
29+ std::ignore = access;
30+ std::ignore = offset;
31+ std::ignore = size;
32+ std::ignore = migrate;
33+ return ptr;
34+ }
35+
36+ void *ur_usm_handle_t_::mapHostPtr (
37+ access_mode_t access, size_t offset, size_t size,
38+ std::function<void (void *src, void *dst, size_t )>) {
39+ std::ignore = access;
40+ std::ignore = offset;
41+ std::ignore = size;
42+ return ptr;
43+ }
44+
45+ void ur_usm_handle_t_::unmapHostPtr (
46+ void *pMappedPtr, std::function<void (void *src, void *dst, size_t )>) {
47+ std::ignore = pMappedPtr;
48+ /* nop */
49+ }
50+
51+ ur_integrated_mem_handle_t ::ur_integrated_mem_handle_t (
52+ ur_context_handle_t hContext, void *hostPtr, size_t size,
53+ host_ptr_action_t hostPtrAction)
2254 : ur_mem_handle_t_(hContext, size) {
2355 bool hostPtrImported = false ;
2456 if (hostPtrAction == host_ptr_action_t ::import ) {
@@ -37,7 +69,7 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
3769 }
3870}
3971
40- ur_host_mem_handle_t ::~ur_host_mem_handle_t () {
72+ ur_integrated_mem_handle_t ::~ur_integrated_mem_handle_t () {
4173 if (ptr) {
4274 auto ret = hContext->getDefaultUSMPool ()->free (ptr);
4375 if (ret != UR_RESULT_SUCCESS) {
@@ -46,21 +78,36 @@ ur_host_mem_handle_t::~ur_host_mem_handle_t() {
4678 }
4779}
4880
49- void *ur_host_mem_handle_t ::getPtr(ur_device_handle_t hDevice) {
81+ void *ur_integrated_mem_handle_t ::getDevicePtr(
82+ ur_device_handle_t hDevice, access_mode_t access, size_t offset,
83+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
5084 std::ignore = hDevice;
85+ std::ignore = access;
86+ std::ignore = offset;
87+ std::ignore = size;
88+ std::ignore = migrate;
5189 return ptr;
5290}
5391
54- ur_result_t ur_device_mem_handle_t::migrateBufferTo (ur_device_handle_t hDevice,
55- void *src, size_t size) {
56- auto Id = hDevice->Id .value ();
92+ void *ur_integrated_mem_handle_t ::mapHostPtr(
93+ access_mode_t access, size_t offset, size_t size,
94+ std::function<void (void *src, void *dst, size_t )> migrate) {
95+ std::ignore = access;
96+ std::ignore = offset;
97+ std::ignore = size;
98+ std::ignore = migrate;
99+ return ptr;
100+ }
57101
58- if (!deviceAllocations[Id]) {
59- UR_CALL (hContext-> getDefaultUSMPool ()-> allocate (hContext, hDevice, nullptr ,
60- UR_USM_TYPE_DEVICE, size,
61- &deviceAllocations[Id]));
62- }
102+ void ur_integrated_mem_handle_t::unmapHostPtr (
103+ void *pMappedPtr, std::function< void ( void *src, void *dst, size_t )>) {
104+ std::ignore = pMappedPtr;
105+ /* nop */
106+ }
63107
108+ static ur_result_t synchronousZeCopy (ur_context_handle_t hContext,
109+ ur_device_handle_t hDevice, void *dst,
110+ const void *src, size_t size) {
64111 auto commandList = hContext->commandListCache .getImmediateCommandList (
65112 hDevice->ZeDevice , true ,
66113 hDevice
@@ -70,26 +117,42 @@ ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
70117 std::nullopt );
71118
72119 ZE2UR_CALL (zeCommandListAppendMemoryCopy,
73- (commandList.get (), deviceAllocations[Id], src, size, nullptr , 0 ,
74- nullptr ));
120+ (commandList.get (), dst, src, size, nullptr , 0 , nullptr ));
121+
122+ return UR_RESULT_SUCCESS;
123+ }
124+
125+ ur_result_t
126+ ur_discrete_mem_handle_t ::migrateBufferTo(ur_device_handle_t hDevice, void *src,
127+ size_t size) {
128+ auto Id = hDevice->Id .value ();
129+
130+ if (!deviceAllocations[Id]) {
131+ UR_CALL (hContext->getDefaultUSMPool ()->allocate (hContext, hDevice, nullptr ,
132+ UR_USM_TYPE_DEVICE, size,
133+ &deviceAllocations[Id]));
134+ }
135+
136+ UR_CALL (
137+ synchronousZeCopy (hContext, hDevice, deviceAllocations[Id], src, size));
75138
76139 activeAllocationDevice = hDevice;
77140
78141 return UR_RESULT_SUCCESS;
79142}
80143
81- ur_device_mem_handle_t :: ur_device_mem_handle_t (ur_context_handle_t hContext,
82- void *hostPtr, size_t size)
144+ ur_discrete_mem_handle_t :: ur_discrete_mem_handle_t (ur_context_handle_t hContext,
145+ void *hostPtr, size_t size)
83146 : ur_mem_handle_t_(hContext, size),
84147 deviceAllocations (hContext->getPlatform ()->getNumDevices()),
85- activeAllocationDevice(nullptr ) {
148+ activeAllocationDevice(nullptr ), hostAllocations() {
86149 if (hostPtr) {
87150 auto initialDevice = hContext->getDevices ()[0 ];
88151 UR_CALL_THROWS (migrateBufferTo (initialDevice, hostPtr, size));
89152 }
90153}
91154
92- ur_device_mem_handle_t ::~ur_device_mem_handle_t () {
155+ ur_discrete_mem_handle_t ::~ur_discrete_mem_handle_t () {
93156 for (auto &ptr : deviceAllocations) {
94157 if (ptr) {
95158 auto ret = hContext->getDefaultUSMPool ()->free (ptr);
@@ -100,8 +163,12 @@ ur_device_mem_handle_t::~ur_device_mem_handle_t() {
100163 }
101164}
102165
103- void *ur_device_mem_handle_t ::getPtr(ur_device_handle_t hDevice) {
104- std::lock_guard lock (this ->Mutex );
166+ void *ur_discrete_mem_handle_t ::getDevicePtrUnlocked(
167+ ur_device_handle_t hDevice, access_mode_t access, size_t offset,
168+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
169+ std::ignore = access;
170+ std::ignore = size;
171+ std::ignore = migrate;
105172
106173 if (!activeAllocationDevice) {
107174 UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
@@ -110,8 +177,10 @@ void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
110177 activeAllocationDevice = hDevice;
111178 }
112179
180+ char *ptr;
113181 if (activeAllocationDevice == hDevice) {
114- return deviceAllocations[hDevice->Id .value ()];
182+ ptr = ur_cast<char *>(deviceAllocations[hDevice->Id .value ()]);
183+ return ptr + offset;
115184 }
116185
117186 auto &p2pDevices = hContext->getP2PDevices (hDevice);
@@ -124,7 +193,71 @@ void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
124193 }
125194
126195 // TODO: see if it's better to migrate the memory to the specified device
127- return deviceAllocations[activeAllocationDevice->Id .value ()];
196+ return ur_cast<char *>(
197+ deviceAllocations[activeAllocationDevice->Id .value ()]) +
198+ offset;
199+ }
200+
201+ void *ur_discrete_mem_handle_t ::getDevicePtr(
202+ ur_device_handle_t hDevice, access_mode_t access, size_t offset,
203+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
204+ std::lock_guard lock (this ->Mutex );
205+ return getDevicePtrUnlocked (hDevice, access, offset, size, migrate);
206+ }
207+
208+ void *ur_discrete_mem_handle_t ::mapHostPtr(
209+ access_mode_t access, size_t offset, size_t size,
210+ std::function<void (void *src, void *dst, size_t )> migrate) {
211+ std::lock_guard lock (this ->Mutex );
212+
213+ // TODO: use async alloc?
214+
215+ void *ptr;
216+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
217+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
218+
219+ hostAllocations.emplace_back (ptr, size, offset, access);
220+
221+ if (activeAllocationDevice && access != access_mode_t ::write_only) {
222+ auto srcPtr =
223+ ur_cast<char *>(deviceAllocations[activeAllocationDevice->Id .value ()]) +
224+ offset;
225+ migrate (srcPtr, hostAllocations.back ().ptr , size);
226+ }
227+
228+ return hostAllocations.back ().ptr ;
229+ }
230+
231+ void ur_discrete_mem_handle_t::unmapHostPtr (
232+ void *pMappedPtr,
233+ std::function<void (void *src, void *dst, size_t )> migrate) {
234+ std::lock_guard lock (this ->Mutex );
235+
236+ for (auto &hostAllocation : hostAllocations) {
237+ if (hostAllocation.ptr == pMappedPtr) {
238+ void *devicePtr = nullptr ;
239+ if (activeAllocationDevice) {
240+ devicePtr = ur_cast<char *>(
241+ deviceAllocations[activeAllocationDevice->Id .value ()]) +
242+ hostAllocation.offset ;
243+ } else if (hostAllocation.access != access_mode_t ::write_invalidate) {
244+ devicePtr = ur_cast<char *>(getDevicePtrUnlocked (
245+ hContext->getDevices ()[0 ], access_mode_t ::read_only,
246+ hostAllocation.offset , hostAllocation.size , migrate));
247+ }
248+
249+ if (devicePtr) {
250+ migrate (hostAllocation.ptr , devicePtr, hostAllocation.size );
251+ }
252+
253+ // TODO: use async free here?
254+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->free (hostAllocation.ptr ));
255+ return ;
256+ }
257+ }
258+
259+ // No mapping found
260+ throw UR_RESULT_ERROR_INVALID_ARGUMENT;
128261}
129262
130263namespace ur ::level_zero {
@@ -155,13 +288,14 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
155288 if (useHostBuffer) {
156289 // TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
157290 // or UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER is set?
158- auto hostPtrAction = flags & UR_MEM_FLAG_USE_HOST_POINTER
159- ? ur_host_mem_handle_t ::host_ptr_action_t ::import
160- : ur_host_mem_handle_t ::host_ptr_action_t ::copy;
291+ auto hostPtrAction =
292+ flags & UR_MEM_FLAG_USE_HOST_POINTER
293+ ? ur_integrated_mem_handle_t ::host_ptr_action_t ::import
294+ : ur_integrated_mem_handle_t ::host_ptr_action_t ::copy;
161295 *phBuffer =
162- new ur_host_mem_handle_t (hContext, hostPtr, size, hostPtrAction);
296+ new ur_integrated_mem_handle_t (hContext, hostPtr, size, hostPtrAction);
163297 } else {
164- *phBuffer = new ur_device_mem_handle_t (hContext, hostPtr, size);
298+ *phBuffer = new ur_discrete_mem_handle_t (hContext, hostPtr, size);
165299 }
166300
167301 return UR_RESULT_SUCCESS;
0 commit comments