Skip to content

Commit a043f45

Browse files
authored
Fix pytorch memory curve estimation for rmm backed allocator (#94)
* Fix PyTorch memory curve estimation. Signed-off-by: Vibhu Jawa <vibhujawa@gmail.com>
* Add test. Signed-off-by: Vibhu Jawa <vibhujawa@gmail.com>
* Add test for RMM. Signed-off-by: Vibhu Jawa <vibhujawa@gmail.com>
* Move imports. Signed-off-by: Vibhu Jawa <vibhujawa@gmail.com>
* Fix based on Praateek's review.
---------
Signed-off-by: Vibhu Jawa <vibhujawa@gmail.com>
1 parent d7e2643 commit a043f45

File tree

3 files changed

+90
-5
lines changed

3 files changed

+90
-5
lines changed

crossfit/backend/torch/hf/memory_curve_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
from transformers import PreTrainedModel
2222

2323
from crossfit.utils.model_adapter import adapt_model_input
24-
from crossfit.utils.torch_utils import cleanup_torch_cache
24+
from crossfit.utils.torch_utils import (
25+
cleanup_torch_cache,
26+
get_peak_memory_used,
27+
reset_memory_tracking,
28+
)
2529

2630

2731
def fit_memory_estimate_curve(
@@ -37,7 +41,7 @@ def fit_memory_estimate_curve(
3741
) -> LinearRegression:
3842
print(f"Fitting memory estimate curve for model: {path_or_name}")
3943

40-
device = next(model.parameters()).device
44+
device = "cuda"
4145
X: list[list[int]] = []
4246
y: list[float] = []
4347

@@ -51,16 +55,16 @@ def fit_memory_estimate_curve(
5155
leave=False,
5256
)
5357
for seq_len in seq_len_pbar:
54-
torch.cuda.reset_peak_memory_stats()
55-
58+
reset_memory_tracking()
5659
batch = {
5760
"input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device),
5861
"attention_mask": torch.ones((batch_size, seq_len)).to(device=device),
5962
}
6063

6164
try:
6265
_ = adapt_model_input(model, batch)
63-
memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
66+
memory_used = get_peak_memory_used()
67+
memory_used = memory_used / (1024**2) # Convert to MB
6468
X.append([batch_size, seq_len, seq_len**2])
6569
y.append(memory_used)
6670

crossfit/utils/torch_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,55 @@ def cleanup_torch_cache() -> None:
9999
gc.collect()
100100
torch.cuda.empty_cache()
101101
return None
102+
103+
104+
def reset_memory_tracking() -> None:
    """Enable memory-usage tracking and zero out the peak counters.

    Dispatches on the active CUDA allocator backend: with the default
    PyTorch allocator the built-in peak statistics are reset; with an
    RMM-backed (pluggable) allocator, RMM statistics are enabled and a
    fresh statistics context is pushed so that a later
    ``get_peak_memory_used()`` call can pop it.

    Returns:
        None
    """
    if not is_torch_memory_rmm():
        # Default PyTorch caching allocator: its own peak counter is enough.
        torch.cuda.reset_peak_memory_stats()
        return

    # RMM-backed allocator: track peaks via an RMM statistics context.
    import rmm

    rmm.statistics.enable_statistics()
    rmm.statistics.push_statistics()
126+
127+
128+
def get_peak_memory_used() -> int:
    """Return the peak GPU memory usage, in bytes.

    Reads the peak from RMM statistics when the RMM allocator backend is
    active (popping the statistics context pushed by
    ``reset_memory_tracking()``); otherwise falls back to PyTorch's native
    CUDA memory stats.

    Returns:
        int: Peak memory usage in bytes.
    """
    if not is_torch_memory_rmm():
        return torch.cuda.max_memory_allocated()

    import rmm

    # Pop the context pushed by reset_memory_tracking() and read its peak.
    return rmm.statistics.pop_statistics().peak_bytes
145+
146+
147+
def is_torch_memory_rmm():
    """Best-effort check for whether torch's CUDA allocator is RMM-backed."""
    # HACK: torch cannot report *which* pluggable allocator is installed,
    # so any "pluggable" backend is assumed to be RMM. Tracked upstream:
    # https://github.com/pytorch/pytorch/issues/133281
    # https://github.com/pytorch/pytorch/issues/133280
    backend = torch.cuda.memory.get_allocator_backend()
    return backend == "pluggable"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
from sklearn.linear_model import LinearRegression

# Skip this whole module when any optional dependency is missing.
transformers = pytest.importorskip("transformers")
torch = pytest.importorskip("torch")
rmm_torch_allocator = pytest.importorskip("rmm.allocators.torch").rmm_torch_allocator
_memory_curve_utils = pytest.importorskip("crossfit.backend.torch.hf.memory_curve_utils")
fit_memory_estimate_curve = _memory_curve_utils.fit_memory_estimate_curve

MODEL_NAME = "microsoft/deberta-v3-base"

# The allocator must be swapped before any CUDA allocation, so this has to
# happen at import time rather than inside a fixture.
# TODO: Long term figure out a better way
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
16+
17+
18+
def test_fit_memory_estimate_curve(tmp_path):
    """Fit the memory curve under the RMM allocator and sanity-check the model."""
    model_save_path = tmp_path / "test_memory_model.joblib"
    hf_model = transformers.AutoModel.from_pretrained(MODEL_NAME).to("cuda")

    fitted = fit_memory_estimate_curve(
        model=hf_model, path_or_name=MODEL_NAME, mem_model_path=str(model_save_path)
    )

    # A linear model over [batch_size, seq_len, seq_len**2] features,
    # persisted to the requested path.
    assert isinstance(fitted, LinearRegression)
    assert fitted.coef_.shape == (3,)  # [batch_size, seq_len, seq_len**2]
    assert isinstance(fitted.intercept_, float)
    assert model_save_path.exists()

0 commit comments

Comments
 (0)