Commit 1b6875e

[Offload] Full AMD support for olMemFill (llvm#154958)
1 parent aaae6ac commit 1b6875e

File tree: 3 files changed, +210 -54 lines

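For orientation: olMemFill is the liboffload entry point whose AMD GPU implementation this commit completes. A rough usage sketch follows; the exact olMemAlloc/olMemFill signatures and the Device/Queue setup are assumptions based on the Offload API, not taken from this diff.

    // Hedged sketch: issue a fill through the Offload API (signatures assumed).
    void *Ptr = nullptr;
    olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, /*Size=*/1024, &Ptr);

    uint16_t Pattern = 0xABCD; // 2-byte pattern: hits the new fast path
    olMemFill(Queue, Ptr, sizeof(Pattern), &Pattern, /*FillSize=*/1024);
    // With this patch the AMD plugin widens 1/2/4-byte patterns to a 32-bit
    // word for hsa_amd_memory_fill; any other pattern size falls back to a
    // staged host-to-device copy (see rtl.cpp below).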

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 88 additions & 25 deletions
@@ -924,6 +924,7 @@ struct AMDGPUStreamTy {
     void *Dst;
     const void *Src;
     size_t Size;
+    size_t NumTimes;
   };
 
   /// Utility struct holding arguments for freeing buffers to memory managers.
@@ -974,9 +975,14 @@ struct AMDGPUStreamTy {
     StreamSlotTy() : Signal(nullptr), Callbacks({}), ActionArgs({}) {}
 
     /// Schedule a host memory copy action on the slot.
-    Error schedHostMemoryCopy(void *Dst, const void *Src, size_t Size) {
+    ///
+    /// NumTimes will repeat the copy that many times, sequentially in the
+    /// destination buffer.
+    Error schedHostMemoryCopy(void *Dst, const void *Src, size_t Size,
+                              size_t NumTimes = 1) {
       Callbacks.emplace_back(memcpyAction);
-      ActionArgs.emplace_back().MemcpyArgs = MemcpyArgsTy{Dst, Src, Size};
+      ActionArgs.emplace_back().MemcpyArgs =
+          MemcpyArgsTy{Dst, Src, Size, NumTimes};
       return Plugin::success();
     }
 
@@ -1216,7 +1222,11 @@ struct AMDGPUStreamTy {
     assert(Args->Dst && "Invalid destination buffer");
     assert(Args->Src && "Invalid source buffer");
 
-    std::memcpy(Args->Dst, Args->Src, Args->Size);
+    auto BasePtr = Args->Dst;
+    for (size_t I = 0; I < Args->NumTimes; I++) {
+      std::memcpy(BasePtr, Args->Src, Args->Size);
+      BasePtr = reinterpret_cast<uint8_t *>(BasePtr) + Args->Size;
+    }
 
     return Plugin::success();
   }
@@ -1421,7 +1431,8 @@ struct AMDGPUStreamTy {
   /// manager once the operation completes.
   Error pushMemoryCopyH2DAsync(void *Dst, const void *Src, void *Inter,
                                uint64_t CopySize,
-                               AMDGPUMemoryManagerTy &MemoryManager) {
+                               AMDGPUMemoryManagerTy &MemoryManager,
+                               size_t NumTimes = 1) {
     // Retrieve available signals for the operation's outputs.
     AMDGPUSignalTy *OutputSignals[2] = {};
     if (auto Err = SignalManager.getResources(/*Num=*/2, OutputSignals))
@@ -1443,7 +1454,8 @@ struct AMDGPUStreamTy {
       // The std::memcpy is done asynchronously using an async handler. We store
       // the function's information in the action but it is not actually a
       // post action.
-      if (auto Err = Slots[Curr].schedHostMemoryCopy(Inter, Src, CopySize))
+      if (auto Err =
+              Slots[Curr].schedHostMemoryCopy(Inter, Src, CopySize, NumTimes))
         return Err;
 
       // Make changes on this slot visible to the async handler's thread.
@@ -1464,7 +1476,11 @@ struct AMDGPUStreamTy {
       std::tie(Curr, InputSignal) = consume(OutputSignal);
     } else {
       // All preceding operations completed, copy the memory synchronously.
-      std::memcpy(Inter, Src, CopySize);
+      auto *InterPtr = Inter;
+      for (size_t I = 0; I < NumTimes; I++) {
+        std::memcpy(InterPtr, Src, CopySize);
+        InterPtr = reinterpret_cast<uint8_t *>(InterPtr) + CopySize;
+      }
 
       // Return the second signal because it will not be used.
       OutputSignals[1]->decreaseUseCount();
@@ -1481,11 +1497,11 @@ struct AMDGPUStreamTy {
     if (InputSignal && InputSignal->load()) {
       hsa_signal_t InputSignalRaw = InputSignal->get();
       return hsa_utils::asyncMemCopy(UseMultipleSdmaEngines, Dst, Agent, Inter,
-                                     Agent, CopySize, 1, &InputSignalRaw,
-                                     OutputSignal->get());
+                                     Agent, CopySize * NumTimes, 1,
+                                     &InputSignalRaw, OutputSignal->get());
     }
     return hsa_utils::asyncMemCopy(UseMultipleSdmaEngines, Dst, Agent, Inter,
-                                   Agent, CopySize, 0, nullptr,
+                                   Agent, CopySize * NumTimes, 0, nullptr,
                                    OutputSignal->get());
   }
 
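The pushMemoryCopyH2DAsync changes above carry the slow-path fill: the host-side action writes NumTimes back-to-back copies of the pattern into the pinned staging buffer, and the device copy then moves CopySize * NumTimes bytes in one transfer. A host-only model of that expansion (standard C++, no HSA; illustrative only):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Expand a PatternSize-byte pattern NumTimes into a staging buffer, the
    // same way the async handler fills the pinned intermediate buffer above.
    std::vector<uint8_t> expandPattern(const void *Pattern, size_t PatternSize,
                                       size_t NumTimes) {
      std::vector<uint8_t> Staging(PatternSize * NumTimes);
      uint8_t *Dst = Staging.data();
      for (size_t I = 0; I < NumTimes; ++I) {
        std::memcpy(Dst, Pattern, PatternSize);
        Dst += PatternSize;
      }
      return Staging; // a single PatternSize * NumTimes device copy follows
    }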

@@ -2611,26 +2627,73 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   Error dataFillImpl(void *TgtPtr, const void *PatternPtr, int64_t PatternSize,
                      int64_t Size,
                      AsyncInfoWrapperTy &AsyncInfoWrapper) override {
-    hsa_status_t Status;
+    // Fast case, where we can use the 4 byte hsa_amd_memory_fill
+    if (Size % 4 == 0 &&
+        (PatternSize == 4 || PatternSize == 2 || PatternSize == 1)) {
+      uint32_t Pattern;
+      if (PatternSize == 1) {
+        auto *Byte = reinterpret_cast<const uint8_t *>(PatternPtr);
+        Pattern = *Byte | *Byte << 8 | *Byte << 16 | *Byte << 24;
+      } else if (PatternSize == 2) {
+        auto *Word = reinterpret_cast<const uint16_t *>(PatternPtr);
+        Pattern = *Word | (*Word << 16);
+      } else if (PatternSize == 4) {
+        Pattern = *reinterpret_cast<const uint32_t *>(PatternPtr);
+      } else {
+        // Shouldn't be here if the pattern size is outwith those values
+        llvm_unreachable("Invalid pattern size");
+      }
 
-    // We can use hsa_amd_memory_fill for this size, but it's not async so the
-    // queue needs to be synchronized first
-    if (PatternSize == 4) {
-      if (AsyncInfoWrapper.hasQueue())
-        if (auto Err = synchronize(AsyncInfoWrapper))
+      if (hasPendingWorkImpl(AsyncInfoWrapper)) {
+        AMDGPUStreamTy *Stream = nullptr;
+        if (auto Err = getStream(AsyncInfoWrapper, Stream))
           return Err;
-      Status = hsa_amd_memory_fill(TgtPtr,
-                                   *static_cast<const uint32_t *>(PatternPtr),
-                                   Size / PatternSize);
 
-      if (auto Err =
-              Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n"))
-        return Err;
-    } else {
-      // TODO: Implement for AMDGPU. Most likely by doing the fill in pinned
-      // memory and copying to the device in one go.
-      return Plugin::error(ErrorCode::UNSUPPORTED, "Unsupported fill size");
+        struct MemFillArgsTy {
+          void *Dst;
+          uint32_t Pattern;
+          int64_t Size;
+        };
+        auto *Args = new MemFillArgsTy{TgtPtr, Pattern, Size / 4};
+        auto Fill = [](void *Data) {
+          MemFillArgsTy *Args = reinterpret_cast<MemFillArgsTy *>(Data);
+          assert(Args && "Invalid arguments");
+
+          auto Status =
+              hsa_amd_memory_fill(Args->Dst, Args->Pattern, Args->Size);
+          delete Args;
+          auto Err =
+              Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n");
+          if (Err) {
+            FATAL_MESSAGE(1, "error performing async fill: %s",
+                          toString(std::move(Err)).data());
+          }
+        };
+
+        // hsa_amd_memory_fill doesn't signal completion using a signal, so use
+        // the existing host callback logic to handle that instead
+        return Stream->pushHostCallback(Fill, Args);
+      } else {
+        // If there is no pending work, do the fill synchronously
+        auto Status = hsa_amd_memory_fill(TgtPtr, Pattern, Size / 4);
+        return Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n");
+      }
     }
+
+    // Slow case; allocate an appropriate memory size and enqueue copies
+    void *PinnedPtr = nullptr;
+    AMDGPUMemoryManagerTy &PinnedMemoryManager =
+        HostDevice.getPinnedMemoryManager();
+    if (auto Err = PinnedMemoryManager.allocate(Size, &PinnedPtr))
+      return Err;
+
+    AMDGPUStreamTy *Stream = nullptr;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
+
+    return Stream->pushMemoryCopyH2DAsync(TgtPtr, PatternPtr, PinnedPtr,
+                                          PatternSize, PinnedMemoryManager,
+                                          Size / PatternSize);
   }
 
   /// Initialize the async info for interoperability purposes.
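The fast path above widens a 1- or 2-byte pattern into the 32-bit word that hsa_amd_memory_fill expects. The replication is plain bit arithmetic and can be checked in isolation (standalone sketch, not part of the diff):

    #include <cstdint>

    // Mirror of the pattern widening in dataFillImpl: replicate a 1-, 2- or
    // 4-byte pattern into a 32-bit fill word.
    uint32_t widenPattern(const void *PatternPtr, int64_t PatternSize) {
      if (PatternSize == 1) {
        uint32_t Byte = *static_cast<const uint8_t *>(PatternPtr);
        return Byte | Byte << 8 | Byte << 16 | Byte << 24; // 0xAB -> 0xABABABAB
      }
      if (PatternSize == 2) {
        uint32_t Word = *static_cast<const uint16_t *>(PatternPtr);
        return Word | (Word << 16); // 0x1234 -> 0x12341234
      }
      return *static_cast<const uint32_t *>(PatternPtr); // PatternSize == 4
    }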

offload/unittests/OffloadAPI/common/Fixtures.hpp

Lines changed: 34 additions & 0 deletions
@@ -89,6 +89,40 @@ template <typename Fn> inline void threadify(Fn body) {
   }
 }
 
+/// Enqueues a task to the queue that can be manually resolved.
+/// It will block until `trigger` is called.
+struct ManuallyTriggeredTask {
+  std::mutex M;
+  std::condition_variable CV;
+  bool Flag = false;
+  ol_event_handle_t CompleteEvent;
+
+  ol_result_t enqueue(ol_queue_handle_t Queue) {
+    if (auto Err = olLaunchHostFunction(
+            Queue,
+            [](void *That) {
+              static_cast<ManuallyTriggeredTask *>(That)->wait();
+            },
+            this))
+      return Err;
+
+    return olCreateEvent(Queue, &CompleteEvent);
+  }
+
+  void wait() {
+    std::unique_lock<std::mutex> lk(M);
+    CV.wait_for(lk, std::chrono::milliseconds(1000), [&] { return Flag; });
+    EXPECT_TRUE(Flag);
+  }
+
+  ol_result_t trigger() {
+    Flag = true;
+    CV.notify_one();
+
+    return olSyncEvent(CompleteEvent);
+  }
+};
+
 struct OffloadTest : ::testing::Test {
   ol_device_handle_t Host = TestEnvironment::getHostDevice();
 };
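A hedged sketch of how a test might use ManuallyTriggeredTask to cover the "pending work" branch added in rtl.cpp: the fixture keeps the queue busy while olMemFill is enqueued, then releases it. The test and fixture names and the Queue, Ptr and Size members are hypothetical, and the olMemFill argument order is an assumption; only olLaunchHostFunction, olCreateEvent and olSyncEvent appear in this diff.

    TEST_P(olMemFillTest, SuccessPendingWork) {
      ManuallyTriggeredTask Task;
      ASSERT_SUCCESS(Task.enqueue(Queue));  // queue now has unfinished work
      uint32_t Pattern = 0xDEADBEEF;
      ASSERT_SUCCESS(olMemFill(Queue, Ptr, sizeof(Pattern), &Pattern, Size));
      ASSERT_SUCCESS(Task.trigger());       // unblock the host task and wait
      ol_event_handle_t Done;
      ASSERT_SUCCESS(olCreateEvent(Queue, &Done));
      ASSERT_SUCCESS(olSyncEvent(Done));    // drain the queue before checks
    }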
