Skip to content

Commit 2852a79

Browse files
committed
add shmem pluggable allocator
1 parent ebf10b3 commit 2852a79

File tree

10 files changed

+2742
-0
lines changed

10 files changed

+2742
-0
lines changed

build.sh

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ BUILD_DEEPEP_MODULE="ON"
55
BUILD_DEEPEP_OPS="ON"
66
BUILD_KERNELS_MODULE="ON"
77
BUILD_MEMORY_SAVER_MODULE="ON"
8+
BUILD_SHMEM_ALLOCATOR_MODULE="ON"
89

910
ONLY_BUILD_DEEPEP_ADAPTER_MODULE="OFF"
1011
ONLY_BUILD_DEEPEP_KERNELs_MODULE="OFF"
1112
ONLY_BUILD_MEMORY_SAVER_MODULE="OFF"
13+
ONLY_BUILD_SHMEM_ALLOCATOR_MODULE="OFF"
1214

1315
DEBUG_MODE="OFF"
1416

@@ -18,6 +20,7 @@ while getopts ":a:hd" opt; do
1820
BUILD_DEEPEP_MODULE="OFF"
1921
BUILD_KERNELS_MODULE="OFF"
2022
BUILD_MEMORY_SAVER_MODULE="OFF"
23+
BUILD_SHMEM_ALLOCATOR_MODULE="OFF"
2124
case "$OPTARG" in
2225
deepep )
2326
BUILD_DEEPEP_MODULE="ON"
@@ -42,6 +45,10 @@ while getopts ":a:hd" opt; do
4245
BUILD_MEMORY_SAVER_MODULE="ON"
4346
ONLY_BUILD_MEMORY_SAVER_MODULE="ON"
4447
;;
48+
shmem-allocator )
49+
BUILD_SHMEM_ALLOCATOR_MODULE="ON"
50+
ONLY_BUILD_SHMEM_ALLOCATOR_MODULE="ON"
51+
;;
4552
* )
4653
echo "Error: Invalid Value"
4754
echo "Allowed value: deepep|kernels|deepep-adapter|deepep-kernels|memory-saver|shmem-allocator"
@@ -61,6 +68,7 @@ while getopts ":a:hd" opt; do
6168
echo " deepep-adapter Only build deepep adapter layer and use old build of deepep kernels."
6269
echo " deepep-kernels Only build deepep kernels and use old build of deepep adapter layer."
6370
echo " memory-saver Only build torch_memory_saver (under contrib)."
71+
echo " shmem-allocator Only build torch-shmem-allocator (under contrib)."
6472
exit 1
6573
;;
6674
\? )
@@ -113,6 +121,7 @@ function build_kernels()
113121
{
114122
if [[ "$ONLY_BUILD_DEEPEP_KERNELs_MODULE" == "ON" ]]; then return 0; fi
115123
if [[ "$ONLY_BUILD_MEMORY_SAVER_MODULE" == "ON" ]]; then return 0; fi
124+
if [[ "$ONLY_BUILD_SHMEM_ALLOCATOR_MODULE" == "ON" ]]; then return 0; fi
116125

117126
CMAKE_DIR=""
118127
BUILD_DIR="build"
@@ -172,6 +181,20 @@ function build_memory_saver()
172181
cd -
173182
}
174183

184+
# Builds the shmem_allocator wheel with setup.py and moves it to OUTPUT_DIR.
# Does nothing unless BUILD_SHMEM_ALLOCATOR_MODULE is "ON".
function build_shmem_allocator()
{
    [[ "$BUILD_SHMEM_ALLOCATOR_MODULE" == "ON" ]] || return 0

    echo "[shmem_allocator] Building shmem_allocator via setup.py"
    cd contrib/shmem_allocator/python || exit

    # Remove leftovers of a previous build before producing a new wheel.
    local pkg_dir="$CURRENT_DIR"/contrib/shmem_allocator/python
    rm -rf "$pkg_dir"/build
    rm -rf "$pkg_dir"/dist
    python3 setup.py clean --all
    python3 setup.py bdist_wheel

    # Hand the wheel over to the output directory, then drop the dist folder.
    mv -v "$pkg_dir"/dist/shmem_allocator*.whl "${OUTPUT_DIR}/"
    rm -rf "$pkg_dir"/dist
    cd -
}
197+
175198
function make_deepep_package()
176199
{
177200
cd python/deep_ep || exit
@@ -209,6 +232,7 @@ function main()
209232
pip3 install wheel==0.45.1
210233
fi
211234
build_memory_saver
235+
build_shmem_allocator
212236
if [[ "$BUILD_DEEPEP_MODULE" == "ON" ]]; then
213237
make_deepep_package
214238
fi

contrib/shmem_allocator/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Torch SHMEM Pluggable Allocator
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
#include "NpuCachingCustomAllocator.h"

#include <algorithm>
#include <bitset>
#include <cstdio>
#include <cstdlib>
#include <deque>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <regex>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
14+
15+
/**
 * Returns a pointer to a mutex stored as a function-local static, so every
 * call hands back the same std::mutex instance.
 */
std::mutex *NpuCachingCustomAllocator::getFreeMutex() const {
  static std::mutex free_guard;
  return std::addressof(free_guard);
}
19+
20+
/**
 * Looks up the Block recorded for `ptr` in allocated_blocks while holding the
 * allocator mutex.
 *
 * @param ptr    Pointer used as the lookup key.
 * @param remove When true, the entry is erased from allocated_blocks.
 * @return The recorded Block, or nullptr when `ptr` is not tracked.
 */
Block *NpuCachingCustomAllocator::get_allocated_block(void *ptr, bool remove) {
  std::lock_guard<std::mutex> guard(mutex);
  auto entry = allocated_blocks.find(ptr);
  if (entry == allocated_blocks.end()) return nullptr;

  Block *found = entry->second;
  if (remove) allocated_blocks.erase(entry);
  return found;
}
32+
33+
void NpuCachingCustomAllocator::init(int device_count) {
34+
int max_device_count = 1000000;
35+
TORCH_INTERNAL_ASSERT(device_count < max_device_count,
36+
"Error, out of maximum device");
37+
int size = static_cast<int>(device_allocator.size());
38+
if (size < device_count) {
39+
device_allocator.resize(device_count);
40+
for (const auto i : c10::irange(size, device_count)) {
41+
device_allocator[i] = std::make_unique<DeviceCachingAllocator>();
42+
}
43+
}
44+
45+
static bool registered = false;
46+
if (!registered) {
47+
std::atexit(finalize);
48+
registered = true;
49+
}
50+
}
51+
52+
bool NpuCachingCustomAllocator::initialized() {
53+
return !device_allocator.empty();
54+
}
55+
56+
/** allocates a block which is safe to use from the provided stream */
57+
void *NpuCachingCustomAllocator::malloc(int device, size_t size,
58+
aclrtStream stream) {
59+
TORCH_INTERNAL_ASSERT(
60+
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
61+
"device index out of range.");
62+
Block *block = device_allocator[device]->malloc(device, size, stream);
63+
TORCH_CHECK(block, "Allocate Block failed.");
64+
add_allocated_block(block);
65+
void *devPtr = static_cast<void *>(block->ptr);
66+
return devPtr;
67+
}
68+
69+
void NpuCachingCustomAllocator::free(void *ptr) {
70+
if (!ptr) {
71+
return;
72+
}
73+
Block *block = get_allocated_block(ptr, true);
74+
if (!block) {
75+
AT_ERROR("invalid device pointer: ", ptr);
76+
}
77+
TORCH_INTERNAL_ASSERT(
78+
0 <= block->device &&
79+
static_cast<size_t>(block->device) < device_allocator.size(),
80+
"device index out of range.");
81+
device_allocator[block->device]->free(block);
82+
}
83+
84+
void NpuCachingCustomAllocator::emptyCache(bool check_error) {
85+
int count = static_cast<int>(device_allocator.size());
86+
for (int i = 0; i < count; i++) device_allocator[i]->emptyCache(check_error);
87+
}
88+
89+
void NpuCachingCustomAllocator::assertValidDevice(int device) {
90+
int device_num = c10_npu::device_count();
91+
AT_ASSERTM(0 <= device && device < device_num, "Invalid device argument.");
92+
}
93+
94+
/**
 * Returns the statistics reported by the per-device allocator for `device`.
 *
 * The index is validated twice: against c10_npu::device_count() through
 * assertValidDevice(), and against the number of per-device allocators that
 * actually exist. The second check mirrors malloc()/free() and turns an index
 * that has no allocator behind it into an error instead of an out-of-bounds
 * access on device_allocator.
 */
DeviceStats NpuCachingCustomAllocator::getDeviceStats(int device) {
  assertValidDevice(device);
  TORCH_INTERNAL_ASSERT(
      0 <= device && static_cast<size_t>(device) < device_allocator.size(),
      "device index out of range.");
  return device_allocator[device]->getStats();
}
98+
99+
void NpuCachingCustomAllocator::resetPeakStats(int device) {
100+
assertValidDevice(device);
101+
device_allocator[device]->resetPeakStats();
102+
}
103+
104+
/** Returns the identifier string of this allocator, always "native". */
std::string NpuCachingCustomAllocator::name() {
  return std::string("native");
}
105+
106+
/**
 * Splits the raw settings string `env` into tokens appended to `config`.
 *
 * ',', ':', '[' and ']' each become a one-character token and also end the
 * token being accumulated. Spaces are dropped. Every other character is
 * appended to the current token. `env` must be a NUL-terminated string.
 */
void CachingAllocatorConfig::lexArgs(const char *env,
                                     std::vector<std::string> &config) {
  std::string pending;
  // Emits the accumulated token, if any, and starts a fresh one.
  const auto flush_pending = [&pending, &config]() {
    if (pending.empty()) return;
    config.push_back(pending);
    pending.clear();
  };

  for (const char *cursor = env; *cursor != '\0'; ++cursor) {
    const char ch = *cursor;
    switch (ch) {
      case ',':
      case ':':
      case '[':
      case ']':
        flush_pending();
        config.emplace_back(1, ch);
        break;
      case ' ':
        break;
      default:
        pending.push_back(ch);
        break;
    }
  }
  flush_pending();
}
126+
127+
void CachingAllocatorConfig::consumeToken(
128+
const std::vector<std::string> &config, size_t i, const char c) {
129+
TORCH_CHECK(i < config.size() && config[i].compare(std::string(1, c)) == 0,
130+
"Error parsing CachingAllocator settings, expected ", c);
131+
}
132+
133+
size_t CachingAllocatorConfig::parseMaxSplitSize(
134+
const std::vector<std::string> &config, size_t i) {
135+
consumeToken(config, ++i, ':');
136+
if (++i < config.size()) {
137+
size_t val1 = 0;
138+
try {
139+
val1 = static_cast<size_t>(stoi(config[i]));
140+
} catch (const std::invalid_argument &e) {
141+
TORCH_CHECK(false, "Error, expecting digit string in config");
142+
} catch (const std::out_of_range &e) {
143+
TORCH_CHECK(false, "Error, out of int range");
144+
}
145+
TORCH_CHECK(
146+
val1 > kLargeBuffer / kUnitMB,
147+
"CachingAllocator option max_split_size_mb too small, must be > ",
148+
kLargeBuffer / kUnitMB);
149+
val1 = std::max(val1, kLargeBuffer / kUnitMB);
150+
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / kUnitMB));
151+
m_max_split_size = val1 * kUnitMB;
152+
} else {
153+
TORCH_CHECK(false, "Error, expecting max_split_size_mb value");
154+
}
155+
return i;
156+
}
157+
158+
size_t CachingAllocatorConfig::parseGarbageCollectionThreshold(
159+
const std::vector<std::string> &config, size_t i) {
160+
consumeToken(config, ++i, ':');
161+
if (++i < config.size()) {
162+
double val1 = 0.0;
163+
try {
164+
val1 = stod(config[i]);
165+
} catch (const std::invalid_argument &e) {
166+
TORCH_CHECK(false, "Error, expecting digital string in config");
167+
} catch (const std::out_of_range &e) {
168+
TORCH_CHECK(false, "Error, out of double range");
169+
}
170+
TORCH_CHECK(val1 > 0,
171+
"garbage_collect_threshold too small, set it 0.0~1.0");
172+
TORCH_CHECK(val1 < 1.0,
173+
"garbage_collect_threshold too big, set it 0.0~1.0");
174+
m_garbage_collection_threshold = val1;
175+
} else {
176+
TORCH_CHECK(false, "Error, expecting garbage_collection_threshold value");
177+
}
178+
return i;
179+
}
180+
181+
size_t CachingAllocatorConfig::parseExpandableSegments(
182+
const std::vector<std::string> &config, size_t i) {
183+
consumeToken(config, ++i, ':');
184+
if (++i < config.size()) {
185+
TORCH_CHECK(
186+
i < config.size() && (config[i] == "True" || config[i] == "False"),
187+
"Expected a single True/False argument for expandable_segments");
188+
m_expandable_segments = (config[i] == "True");
189+
if (m_expandable_segments) {
190+
void *ptr = nullptr;
191+
constexpr size_t virtual_mem_size = 512;
192+
auto status = aclrtReserveMemAddress(&ptr, virtual_mem_size, 0, NULL, 1);
193+
if (status == ACL_ERROR_NONE) {
194+
TORCH_CHECK(aclrtReleaseMemAddress(ptr) == ACL_ERROR_NONE,
195+
"aclrtReleaseMemAddress failed.");
196+
} else {
197+
NPU_CHECK_SUPPORT_OR_ERROR(status);
198+
m_expandable_segments = false;
199+
}
200+
}
201+
} else {
202+
TORCH_CHECK(false, "Error, expecting expandable_segments value");
203+
}
204+
return i;
205+
}
206+
207+
void CachingAllocatorConfig::parseArgs(const char *env) {
208+
// If empty, set the default values
209+
m_max_split_size = std::numeric_limits<size_t>::max();
210+
m_garbage_collection_threshold = 0;
211+
212+
if (env == nullptr) {
213+
return;
214+
}
215+
216+
std::vector<std::string> config;
217+
lexArgs(env, config);
218+
219+
for (size_t i = 0; i < config.size(); i++) {
220+
if (config[i].compare("max_split_size_mb") == 0) {
221+
i = parseMaxSplitSize(config, i);
222+
} else if (config[i].compare("garbage_collection_threshold") == 0) {
223+
i = parseGarbageCollectionThreshold(config, i);
224+
} else if (config[i] == "expandable_segments") {
225+
set_expandable_segments_flag = true;
226+
i = parseExpandableSegments(config, i);
227+
} else {
228+
TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", config[i]);
229+
}
230+
231+
if (i + 1 < config.size()) {
232+
consumeToken(config, ++i, ',');
233+
}
234+
}
235+
if (m_expandable_segments) {
236+
if (set_expandable_segments_flag) {
237+
} else if (m_max_split_size != std::numeric_limits<size_t>::max() ||
238+
m_garbage_collection_threshold != 0) {
239+
m_expandable_segments = false;
240+
}
241+
}
242+
}
243+
244+
// File-scope allocator instance; local_raw_delete() forwards frees to it.
NpuCachingCustomAllocator my_allocator;
245+
246+
/** Deleter-style free function: releases `ptr` through my_allocator. */
void local_raw_delete(void *ptr) {
  my_allocator.free(ptr);
}
247+
248+
void finalize() {
249+
// uninit shmem handle(need be done in collective)
250+
// for (const auto i : c10::irange(0, shm_ptr_meta.size())) {
251+
// shmem_free(shm_ptr_meta[i]);
252+
// }
253+
auto status = shmem_finalize();
254+
}

0 commit comments

Comments
 (0)