
Commit e6afb51

muchulee8 authored and pytorchmergebot committed
[AOTInductor] Free folded constants that are managed internally by AOTInductor (pytorch#149825)

Summary: This diff allocates the folded constants created by AOTInductor through the CUDACachingAllocator, rather than placing them in the cudaMalloc'd constant blob, which allows them to be freed independently.

Test Plan: LD_LIBRARY_PATH=/data/users/$USER/pytorch/build/lib /home/$USER/local/pytorch/build/bin/test_aoti_inference

Pull Request resolved: pytorch#149825
Approved by: https://github.com/chenyang78, https://github.com/desertfire, https://github.com/jingsh
1 parent e080bac commit e6afb51

4 files changed: +137 −37 lines
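For orientation before the per-file diffs, here is a minimal sketch (not part of the commit) of the double-buffered weight-update flow this change affects. It uses only runner APIs that the test below exercises; the map type is left generic because the exact parameter type of update_inactive_constant_buffer is not shown in this diff.

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>

// Sketch only: WeightMap stands in for whatever map type
// update_inactive_constant_buffer accepts (rand_map in the test below).
template <typename WeightMap>
void refresh_weights(
    torch::inductor::AOTIModelContainerRunnerCuda& runner,
    const WeightMap& new_weights) {
  // Write new weights into the inactive buffer, then re-run constant
  // folding against it so its folded constants are rebuilt.
  runner.update_inactive_constant_buffer(new_weights);
  runner.run_const_fold(/* use_inactive = */ true);

  // Activate the updated buffer and release the old one. With this commit,
  // the old buffer's folded constants are returned through the
  // CUDACachingAllocator rather than freed with the cudaMalloc'd blob.
  runner.swap_constant_buffer();
  runner.free_inactive_constant_buffer();
}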

test/cpp/aoti_inference/test.cpp

Lines changed: 74 additions & 13 deletions
@@ -13,6 +13,7 @@
 #include <torch/csrc/inductor/aoti_package/model_package_loader.h>
 #include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
 #if defined(USE_CUDA)
+#include <c10/cuda/CUDACachingAllocator.h>
 #include <cuda_runtime.h>
 #endif
 #if defined(USE_CUDA) || defined(USE_ROCM)
@@ -327,20 +328,26 @@ void test_aoti_double_buffering_with_tensor_constants() {
   ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
 }
 
-void test_aoti_free_buffer() {
+void test_aoti_free_buffer(bool use_runtime_constant_folding) {
   torch::NoGradGuard no_grad;
+  size_t allocated, reserved, active;
 
   std::string data_path =
       (std::filesystem::path(
           STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "large_data.pt")
          .string();
 
   // Memory information variable
-  cudaError_t cudaStatus;
   size_t DATASIZE = 128 * 1024 * 1024; // We have 128MB of weight data.
+  size_t FOLDEDDATASIZE = use_runtime_constant_folding
+      ? 64 * 1024 * 1024
+      : 0; // We have 64MB of folded data.
 
   torch::jit::script::Module data_loader = torch::jit::load(data_path);
   std::string path_attr = "model_so_path";
+  if (use_runtime_constant_folding) {
+    path_attr += std::string("_use_runtime_constant_folding");
+  }
   std::string inputs_attr = "inputs";
   std::string outputs_attr = "outputs";
   std::string weights_attr = "w_pre";
@@ -365,7 +372,16 @@ void test_aoti_free_buffer() {
   runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
       model_so_path);
 
-  // We extract the initial memory here.
+  // We extract the memory information starting from here.
+  int device_idx = -1;
+  cudaError_t cudaStatus;
+  cudaStatus = cudaGetDevice(&device_idx);
+  if (cudaStatus != cudaSuccess || device_idx == -1) {
+    throw std::runtime_error("cudaGetDevice failed!");
+  }
+  c10::cuda::CUDACachingAllocator::DeviceStats stats =
+      c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
+  // This should contain one set of weights (128MB) loaded from the .so.
   size_t initMemory = 0;
   size_t totalMemory = 0;
   cudaStatus = cudaMemGetInfo(&initMemory, &totalMemory);
@@ -382,42 +398,83 @@ void test_aoti_free_buffer() {
   }
   ASSERT_EQ(initMemory - DATASIZE, updateMemory2);
 
+  // Call run; this should run const_fold and create the folded constants
+  // in #2 (64MB).
+  if (use_runtime_constant_folding) {
+    runner->run_const_fold(/* use_inactive = */ true);
+    size_t constFoldMemory = 0;
+    cudaStatus = cudaMemGetInfo(&constFoldMemory, &totalMemory);
+    if (cudaStatus != cudaSuccess) {
+      throw std::runtime_error("cudaMemGetInfo failed!");
+    }
+    ASSERT_EQ(initMemory - DATASIZE - FOLDEDDATASIZE, constFoldMemory);
+  }
+
   // We swap and free the inactive buffer. (Use #2 and free #1)
+  // Note that buffer #1 does not include folded constants.
   runner->swap_constant_buffer();
   runner->free_inactive_constant_buffer();
   size_t postFreeMemory = 0;
   cudaStatus = cudaMemGetInfo(&postFreeMemory, &totalMemory);
   if (cudaStatus != cudaSuccess) {
     throw std::runtime_error("cudaMemGetInfo failed!");
   }
-  // We should only have one set of buffer (#2), memory used should equal
-  // initial memory.
-  ASSERT_EQ(initMemory, postFreeMemory);
+  // We should only have one set of buffers (#2); available memory should
+  // equal initial memory minus the folded constants.
+  ASSERT_EQ(initMemory - FOLDEDDATASIZE, postFreeMemory);
 
-  // We update random weights to buffer #1.
+  // We update random weights to buffer #1 and run const fold.
+  // We will have 2 full sets of data plus 2 sets of const-folded data.
   runner->update_inactive_constant_buffer(rand_map);
+  runner->run_const_fold(/* use_inactive = */ true);
   size_t updateMemory1 = 0;
   cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory);
   if (cudaStatus != cudaSuccess) {
     throw std::runtime_error("cudaMemGetInfo failed!");
   }
-  ASSERT_EQ(initMemory - DATASIZE, updateMemory1);
-
-  // Test if we directly free the buffer #1.
+  ASSERT_EQ(initMemory - DATASIZE - 2 * FOLDEDDATASIZE, updateMemory1);
+
+  // We directly free buffer #1. This frees the DATASIZE weight.
+  // If folded constants exist, the cudaMalloc'd memory is not freed
+  // directly; the active bytes in the CachingAllocator decrease instead.
+  size_t active1, active2;
+  size_t allocated1, allocated2;
+  stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
+  active1 = stats.active_bytes[0].current;
+  allocated1 = stats.allocated_bytes[0].current;
   runner->free_inactive_constant_buffer();
   cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory);
   if (cudaStatus != cudaSuccess) {
     throw std::runtime_error("cudaMemGetInfo failed!");
   }
-  ASSERT_EQ(initMemory, updateMemory1);
+  stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
+  active2 = stats.active_bytes[0].current;
+  allocated2 = stats.allocated_bytes[0].current;
+  ASSERT_EQ(initMemory - 2 * FOLDEDDATASIZE, updateMemory1);
+  ASSERT_EQ(FOLDEDDATASIZE, active1 - active2);
 
   // Free buffer #1 again; since #1 is freed, nothing should change.
   runner->free_inactive_constant_buffer();
   cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory);
   if (cudaStatus != cudaSuccess) {
     throw std::runtime_error("cudaMemGetInfo failed!");
   }
-  ASSERT_EQ(initMemory, updateMemory1);
+  ASSERT_EQ(initMemory - 2 * FOLDEDDATASIZE, updateMemory1);
+  ASSERT_EQ(FOLDEDDATASIZE, active1 - active2);
+
+  // Swap and free #2; no weight data should remain in memory now.
+  // However, the folded constants still occupy CUDA memory inside the
+  // CachingAllocator.
+  runner->swap_constant_buffer();
+  runner->free_inactive_constant_buffer();
+  stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
+  active2 = stats.active_bytes[0].current;
+  cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory);
+  if (cudaStatus != cudaSuccess) {
+    throw std::runtime_error("cudaMemGetInfo failed!");
+  }
+  ASSERT_EQ(initMemory + DATASIZE - 2 * FOLDEDDATASIZE, updateMemory1);
+  ASSERT_EQ(2 * FOLDEDDATASIZE, active1 - active2);
 }
 
 class ThreadPool {
422479

423480
class ThreadPool {
@@ -612,7 +669,11 @@ TEST(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
612669
}
613670

614671
TEST(AotInductorTest, FreeInactiveConstantBufferCuda) {
615-
test_aoti_free_buffer();
672+
test_aoti_free_buffer(false);
673+
}
674+
675+
TEST(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) {
676+
test_aoti_free_buffer(true);
616677
}
617678

618679
TEST(AotInductorTest, MultiStreamTestCuda) {
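The key observable in this test is the gap between the two memory views it samples. A small standalone sketch of that measurement, using only the calls the test itself makes (assumption: pool index 0 of the stat arrays, as used above):

#include <c10/cuda/CUDACachingAllocator.h>
#include <cuda_runtime.h>

#include <cstdio>
#include <stdexcept>

// Device-level free memory only moves when cudaMalloc/cudaFree run;
// allocator-level active_bytes drops as soon as a folded-constant tensor
// is released back into the caching allocator's pool.
void report_cuda_memory() {
  int device_idx = -1;
  if (cudaGetDevice(&device_idx) != cudaSuccess || device_idx == -1) {
    throw std::runtime_error("cudaGetDevice failed!");
  }

  size_t free_bytes = 0, total_bytes = 0;
  if (cudaMemGetInfo(&free_bytes, &total_bytes) != cudaSuccess) {
    throw std::runtime_error("cudaMemGetInfo failed!");
  }

  c10::cuda::CUDACachingAllocator::DeviceStats stats =
      c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
  std::printf(
      "device free: %zu, allocator active: %zu, allocator allocated: %zu\n",
      free_bytes,
      stats.active_bytes[0].current,
      stats.allocated_bytes[0].current);
}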

test/cpp/aoti_inference/test.py

Lines changed: 28 additions & 20 deletions
@@ -93,29 +93,37 @@ def generate_large_tests():
     ref_output = model(x)
 
     torch._dynamo.reset()
-    with torch.no_grad():
-        model_so_path = aot_compile(
-            model,
-            (x,),
-        )
-        # Also store a .pt2 file using the aoti_compile_and_package API
-        pt2_package_path = torch._inductor.aoti_compile_and_package(
-            torch.export.export(
+    for use_runtime_constant_folding in [True, False]:
+        with torch.no_grad():
+            model_so_path = aot_compile(
                 model,
                 (x,),
-            ),
-        )
+                options={
+                    "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
+                },
+            )
+            # Also store a .pt2 file using the aoti_compile_and_package API
+            pt2_package_path = torch._inductor.aoti_compile_and_package(
+                torch.export.export(
+                    model,
+                    (x,),
+                ),
+                inductor_configs={
+                    "aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding
+                },
+            )
 
-    large_data.update(
-        { # noqa: F541
-            "model_so_path": model_so_path,
-            "pt2_package_path": pt2_package_path,
-            "inputs": [x],
-            "outputs": [ref_output],
-            "w_pre": model.w_pre,
-            "w_add": model.w_add,
-        }
-    )
+        suffix = "_use_runtime_constant_folding" if use_runtime_constant_folding else ""
+        large_data.update(
+            { # noqa: F541
+                f"model_so_path{suffix}": model_so_path,
+                f"pt2_package_path{suffix}": pt2_package_path,
+                "inputs": [x],
+                "outputs": [ref_output],
+                "w_pre": model.w_pre,
+                "w_add": model.w_add,
+            }
+        )
 
 
 # AOTI model which will create additional tensors during autograd.
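On the consumer side, test.cpp resolves these suffixed keys back out of the data module. A sketch of that lookup (the attr()/toStringRef() calls are the standard torch::jit Module/IValue API, assumed here rather than shown in this diff):

#include <torch/script.h>

#include <string>

// Pick the .so path for the requested variant out of the data module
// written by generate_large_tests() above.
std::string model_so_path_for(
    const std::string& data_path,
    bool use_runtime_constant_folding) {
  torch::jit::script::Module data_loader = torch::jit::load(data_path);
  std::string path_attr = "model_so_path";
  if (use_runtime_constant_folding) {
    path_attr += "_use_runtime_constant_folding";
  }
  return data_loader.attr(path_attr).toStringRef();
}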

torch/csrc/inductor/aoti_runtime/model.h

Lines changed: 20 additions & 3 deletions
@@ -290,9 +290,11 @@ class AOTInductorModelBase {
 
   void load_constants() {
     size_t num_constants = this->num_constants();
+    size_t num_folded_constants = this->num_folded_constants();
     constants_map_->reserve(num_constants);
 
-    std::vector<size_t> constants_internal_offset(num_constants);
+    std::vector<size_t> constants_internal_offset(
+        num_constants - num_folded_constants);
     size_t blob_size = 0;
     compute_constant_blob(blob_size, constants_internal_offset);
 #if defined(USE_CUDA) || defined(USE_XPU)
@@ -317,7 +319,7 @@ class AOTInductorModelBase {
             constants_internal_offset[i],
             bytes_read,
             data_size,
-            from_folded)
+            /* skip_copy = */ false)
       : nullptr;
     bytes_read += data_size;
 
@@ -401,13 +403,17 @@ class AOTInductorModelBase {
       std::vector<size_t>& constants_internal_offset) {
     size_t num_constants = this->num_constants();
     blob_size = 0;
+    size_t curr_idx = 0;
    for (size_t i = 0; i < num_constants; i++) {
+      if (this->constant_from_folded(i)) {
+        continue;
+      }
       size_t data_size = this->constant_data_size(i);
       if (data_size % AOTI_CONST_ALIGNMENT) {
         data_size = AOTI_CONST_ALIGNMENT +
             (data_size / AOTI_CONST_ALIGNMENT) * AOTI_CONST_ALIGNMENT;
       }
-      constants_internal_offset[i] = blob_size;
+      constants_internal_offset[curr_idx++] = blob_size;
       blob_size += data_size;
     }
   }
@@ -424,6 +430,17 @@ class AOTInductorModelBase {
     return constants_info_.size();
   }
 
+  size_t num_folded_constants() const {
+    size_t total_consts = this->num_constants();
+    size_t folded_consts = 0;
+    for (size_t i = 0; i < total_consts; i++) {
+      if (this->constant_from_folded(i)) {
+        folded_consts++;
+      }
+    }
+    return folded_consts;
+  }
+
   const char* input_name(int64_t idx) const {
     return inputs_info_.at(idx).name;
   }
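The effect of the changed compute_constant_blob is easy to show in isolation. A self-contained sketch (ALIGNMENT and the input vectors are hypothetical stand-ins for AOTI_CONST_ALIGNMENT and the constants_info_ queries):

#include <cstddef>
#include <vector>

constexpr size_t ALIGNMENT = 64; // stand-in for AOTI_CONST_ALIGNMENT

// Folded constants are skipped entirely, so blob offsets are packed
// densely over the non-folded constants only, each size rounded up to
// the alignment boundary exactly as in compute_constant_blob above.
std::vector<size_t> pack_offsets(
    const std::vector<size_t>& sizes,
    const std::vector<bool>& from_folded) {
  std::vector<size_t> offsets;
  size_t blob_size = 0;
  for (size_t i = 0; i < sizes.size(); i++) {
    if (from_folded[i]) {
      continue; // folded constants no longer live in the constant blob
    }
    size_t data_size = sizes[i];
    if (data_size % ALIGNMENT) {
      data_size = ALIGNMENT + (data_size / ALIGNMENT) * ALIGNMENT;
    }
    offsets.push_back(blob_size);
    blob_size += data_size;
  }
  return offsets;
}

For example, sizes {100, 64, 100} with from_folded {false, true, false} yields offsets {0, 128}: the folded 64-byte constant contributes nothing to the blob, and the first constant is rounded up to 128 bytes.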

torch/csrc/inductor/aoti_runtime/model_container.h

Lines changed: 15 additions & 1 deletion
@@ -53,7 +53,8 @@ class AOTInductorModelContainer {
     }
     model->load_constants();
     constant_blob_ = model->release_constant_blob();
-    constants_internal_offset_.resize(model->num_constants());
+    constants_internal_offset_.resize(
+        model->num_constants() - model->num_folded_constants());
     model->compute_constant_blob(blob_size_, constants_internal_offset_);
 
     for (auto& model : models_) {
@@ -453,6 +454,19 @@ class AOTInductorModelContainer {
     } else {
       constant_blob_secondary_.reset();
     }
+    // Free the internally held constants
+    int num_constants = static_cast<int>(models_[0]->num_constants());
+    std::shared_ptr<ConstantMap> to_free_map =
+        use_secondary_ ? constants_map_ : constants_map_secondary_;
+
+    for (int i = 0; i < num_constants; i++) {
+      if (models_[0]->constant_from_folded(i)) {
+        auto it = to_free_map->find(models_[0]->constant_name(i));
+        if (it != to_free_map->end()) {
+          it->second.reset();
+        }
+      }
+    }
   }
 
   size_t num_inputs() const {
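Why resetting the map entry suffices: with this commit the folded constants are ordinary allocator-backed tensors, so dropping the last reference returns their storage to the CUDACachingAllocator pool (active_bytes decreases) without a cudaFree, which is exactly what the test asserts. A simplified sketch of that mechanism; SimpleConstantMap is a stand-in, not the runtime's actual ConstantMap type:

#include <memory>
#include <string>
#include <unordered_map>

#include <ATen/ATen.h>

// Stand-in for the runtime's ConstantMap; the real map holds tensor
// handles rather than std::shared_ptr<at::Tensor>.
using SimpleConstantMap =
    std::unordered_map<std::string, std::shared_ptr<at::Tensor>>;

void release_folded_constant(SimpleConstantMap& map, const std::string& name) {
  auto it = map.find(name);
  if (it != map.end()) {
    // Dropping the reference frees the storage back into the caching
    // allocator's pool; device-level free memory is unchanged.
    it->second.reset();
  }
}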
