Commit 3b0f393

Revert "Adds snapshot API for MemPools to get pool memory segments (pytorch#133601)"
This reverts commit 00504aa. Reverted pytorch#133601 on behalf of https://github.com/wdvr: reverting for now, as it breaks a number of internal tests. Details below ([comment](pytorch#133601 (comment)))
1 parent 5916def · commit 3b0f393

3 files changed: +24 -85 lines changed

c10/cuda/CUDACachingAllocator.cpp

Lines changed: 7 additions & 32 deletions
@@ -1897,41 +1897,16 @@ class DeviceCachingAllocator {
 
     std::unordered_map<PrivatePool*, MempoolId_t> pool_to_id;
     pool_to_id.reserve(graph_pools.size() + graph_pools_freeable.size());
-    std::vector<Block*> all_blocks;
-    MempoolId_t mempool_id = {0, 0};
-
-    auto active_mempool = MemPoolContext::getActiveMemPool();
-    if (active_mempool) {
-      mempool_id = active_mempool->id();
+    for (const auto& pair : graph_pools) {
+      pool_to_id[pair.second.get()] = pair.first;
     }
-
-    if (mempool_id.first != 0 || mempool_id.second != 0) {
-      // If there is an active mempool, we find the corresponding PrivatePool
-      // in graph_pools and only return the blocks from it.
-      auto pool = graph_pools.find(mempool_id);
-      if (pool != graph_pools.end()) {
-        pool_to_id[pool->second.get()] = pool->first;
-        all_blocks = get_private_pool_head_blocks(pool->second.get());
-      }
-      auto pool_freeable = graph_pools_freeable.find(mempool_id);
-      if (pool_freeable != graph_pools_freeable.end()) {
-        pool_to_id[pool_freeable->second] = pool_freeable->first;
-      }
-    } else {
-      // When snapshot is called outside a MemPoolContext, we return
-      // all the blocks in the CUDACachingAllocator (as returned by
-      // get_all_blocks).
-      for (const auto& pair : graph_pools) {
-        pool_to_id[pair.second.get()] = pair.first;
-      }
-      for (const auto& pair : graph_pools_freeable) {
-        pool_to_id[pair.second] = pair.first;
-      }
-      all_blocks = get_all_blocks();
+    for (const auto& pair : graph_pools_freeable) {
+      pool_to_id[pair.second] = pair.first;
     }
 
     size_t total_active = 0;
     std::vector<SegmentInfo> result;
+    const auto all_blocks = get_all_blocks();
 
     for (const Block* const head_block : all_blocks) {
       // For expandable segments, we report one segment for each contiguous
@@ -2134,8 +2109,8 @@ class DeviceCachingAllocator {
  private:
   // All private methods do not acquire the allocator mutex.
 
-  std::vector<Block*> get_all_blocks() const {
-    std::vector<Block*> blocks;
+  std::vector<const Block*> get_all_blocks() const {
+    std::vector<const Block*> blocks;
     blocks.insert(
         blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end());
     blocks.insert(
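
The removed branch made DeviceCachingAllocator::snapshot() pool-aware: with an active MemPool it collected only that pool's head blocks via get_private_pool_head_blocks(), and fell back to get_all_blocks() only when no pool was active; the restored code always walks get_all_blocks(). A minimal Python sketch of the user-visible difference, mirroring the removed MemPool.snapshot() wrapper later in this commit. The pool-filtered result only applies to a build that still contains pytorch#133601; the sketch also assumes a CUDA device and the default MemPool() constructor.

import torch

pool = torch.cuda.MemPool()  # private pool tracked by the CUDACachingAllocator

with torch.cuda.use_mem_pool(pool):
    buf = torch.randn(1024, device="cuda")  # allocation routed into `pool`

# Mirror of the removed MemPool.snapshot(): activate the pool, then snapshot.
ctx = torch.cuda.MemPoolContext(pool)
try:
    # Pre-revert: only segments belonging to `pool` are returned here.
    pool_segments = torch.cuda.memory_snapshot()
finally:
    del ctx

# With no active pool, every segment in the allocator is returned.
all_segments = torch.cuda.memory_snapshot()
print(len(pool_segments), len(all_segments))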

test/test_cuda.py

Lines changed: 1 addition & 19 deletions
@@ -4542,10 +4542,9 @@ def test_mempool_with_allocator(self):
         alloc_lib = ctypes.CDLL(dummy_allocator)
         called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc")
         self.assertEqual(called_dummy_alloc.value, 0)
-        nelem_1mb = 1024 * 1024 // 4
 
         with torch.cuda.use_mem_pool(pool):
-            out_0 = torch.randn(nelem_1mb, device="cuda")
+            out_0 = torch.randn(1, device="cuda")
 
             # pool's use count should be 2 at this point as use_mem_pool
             # holds a reference
@@ -4559,23 +4558,6 @@ def test_mempool_with_allocator(self):
         # out tensor
         self.assertEqual(called_dummy_alloc.value, 123)
 
-        with torch.cuda.use_mem_pool(pool):
-            # pool should have 1 segment since we made a small allocation (1 MB)
-            # above and so the CUDACachingAllocator packed it into a 2 MB buffer
-            self.assertEqual(len(pool.snapshot()), 1)
-
-            out_1 = torch.randn(nelem_1mb, device="cuda")
-
-            # pool should still have 1 segment since we made another small allocation
-            # (1 MB) that got packed into the existing 2 MB buffer
-            self.assertEqual(len(pool.snapshot()), 1)
-
-            out_2 = torch.randn(nelem_1mb, device="cuda")
-
-            # pool now should have 2 segments since the CUDACachingAllocator had
-            # to make a new 2 MB buffer to accomodate out_2
-            self.assertEqual(len(pool.snapshot()), 2)
-
     def test_mempool_context(self):
         active_pool = torch.cuda.MemPoolContext.active_pool()
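
The deleted assertions counted allocator segments, not tensors: each nelem_1mb tensor holds 1024 * 1024 // 4 float32 elements (about 1 MB), and the removed comments note that the CUDACachingAllocator packs such small allocations into 2 MB buffers, so two tensors share one segment and a third forces a second. A condensed sketch of that reasoning; it is only meaningful on a pre-revert build where pool.snapshot() exists, so the snapshot calls are left as comments.

import torch

nelem_1mb = 1024 * 1024 // 4  # number of float32 elements in ~1 MB
pool = torch.cuda.MemPool()

with torch.cuda.use_mem_pool(pool):
    out_0 = torch.randn(nelem_1mb, device="cuda")  # ~1 MB, packed into a 2 MB segment
    out_1 = torch.randn(nelem_1mb, device="cuda")  # fits into the same 2 MB segment
    # Pre-revert: len(pool.snapshot()) == 1 here (one segment, two tensors).
    out_2 = torch.randn(nelem_1mb, device="cuda")  # forces a second 2 MB segment
    # Pre-revert: len(pool.snapshot()) == 2 here; segments are counted, not tensors.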

torch/cuda/memory.py

Lines changed: 16 additions & 34 deletions
@@ -980,25 +980,6 @@ def _get_current_allocator() -> _CUDAAllocator:
     return _CUDAAllocator(torch._C._cuda_getAllocator())
 
 
-class MemPoolContext(_MemPoolContext):
-    r"""MemPoolContext holds the currently active pool and stashes the previous
-    pool. On deletion it makes the previous pool active.
-
-    Args:
-        pool(torch.cuda.MemPool): a MemPool object to be made active so that
-            allocations route to this pool.
-
-    """
-
-    def __init__(self, pool: _MemPool):
-        super().__init__(pool)
-
-    @staticmethod
-    def active_pool() -> Optional[_MemPool]:
-        r"""Returns the active MemPool"""
-        return _MemPoolContext.active_pool()
-
-
 class MemPool(_MemPool):
     r"""MemPool represents a pool of memory in a caching allocator. Currently,
     it's just the ID of the pool object maintained in the CUDACachingAllocator.
@@ -1029,23 +1010,24 @@ def use_count(self) -> int:
         r"""Returns the reference count of this pool."""
         return super().use_count()
 
-    def snapshot(self):
-        r"""Return a snapshot of the CUDA memory allocator pool state across all
-        devices.
 
-        Interpreting the output of this function requires familiarity with the
-        memory allocator internals.
+class MemPoolContext(_MemPoolContext):
+    r"""MemPoolContext holds the currently active pool and stashes the previous
+    pool. On deletion it makes the previous pool active.
 
-        .. note::
-            See :ref:`cuda-memory-management` for more details about GPU memory
-            management.
-        """
-        try:
-            ctx = MemPoolContext(self)
-            snapshot = torch.cuda.memory_snapshot()
-        finally:
-            del ctx
-        return snapshot
+    Args:
+        pool(torch.cuda.MemPool): a MemPool object to be made active so that
+            allocations route to this pool.
+
+    """
+
+    def __init__(self, pool: MemPool):
+        super().__init__(pool)
+
+    @staticmethod
+    def active_pool() -> Optional[_MemPool]:
+        r"""Returns the active MemPool"""
+        return _MemPoolContext.active_pool()
 
 
 @contextlib.contextmanager
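
After the revert, torch/cuda/memory.py keeps MemPool and MemPoolContext (now defined after MemPool and typed against it) but no longer exposes MemPool.snapshot(). A short usage sketch of the APIs that remain; the reference-count values follow the comments in test_mempool_with_allocator above and should be read as expectations under these assumptions, not guarantees.

import torch

# Assumes no other pool has been made active in this process.
print(torch.cuda.MemPoolContext.active_pool())  # expected: None

pool = torch.cuda.MemPool()
print(pool.use_count())  # 1: only our handle references the pool

with torch.cuda.use_mem_pool(pool):
    # use_mem_pool holds an extra reference while the pool is active
    print(pool.use_count())  # expected: 2
    print(torch.cuda.MemPoolContext.active_pool() is not None)  # expected: True

print(pool.use_count())  # expected to drop back to 1 once use_mem_pool releases its reference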
