Skip to content

Commit b6dbba1

Browse files
Fix torch 2.9+ allocator config precedence for user env vars (#501)
* Preserve user allocator config precedence on torch 2.9+
* Fix standby and ROCm behavior in allocator promotion

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
1 parent 58303bf commit b6dbba1

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

unsloth_zoo/__init__.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@
8080
"We also have some installation instructions on our Github page."
8181
)
8282

83+
# Snapshot the user's allocator-related environment before Unsloth mutates it,
# so an explicitly provided configuration keeps precedence later on.
_ORIGINAL_PYTORCH_CUDA_ALLOC_CONF = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
_ORIGINAL_PYTORCH_HIP_ALLOC_CONF = os.environ.get("PYTORCH_HIP_ALLOC_CONF", None)
_HAS_ORIGINAL_PYTORCH_ALLOC_CONF = "PYTORCH_ALLOC_CONF" in os.environ
87+
8388
# Reduce VRAM usage by reducing fragmentation
8489
# And optimize pinning of memory
8590
if os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0":
@@ -132,6 +137,20 @@ def remove_expandable_segments(key):
132137
delete_key(key)
133138

134139

140+
def clean_expandable_segments_value(value):
    """Strip ``expandable_segments`` entries from an allocator config string.

    Returns ``value`` unchanged when it is ``None`` or mentions no
    ``expandable_segments`` setting. Otherwise returns the remaining
    comma-separated entries (whitespace-trimmed, empties dropped), or
    ``None`` when nothing is left after filtering.
    """
    if value is None or "expandable_segments" not in value:
        return value
    # Keep every non-empty entry that is not an expandable_segments setting.
    kept = [
        entry
        for entry in (piece.strip() for piece in value.split(","))
        if entry and not entry.startswith("expandable_segments:")
    ]
    return ",".join(kept) if kept else None
152+
153+
135154
if (major_torch < 2):
136155
raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
137156
"We have some installation instructions on our Github page.")
@@ -186,10 +205,18 @@ def filter(self, x): return not (self.text in x.getMessage())
186205

187206
# Torch 2.9 removed PYTORCH_HIP_ALLOC_CONF and PYTORCH_CUDA_ALLOC_CONF
188207
if major_torch == 2 and minor_torch >= 9:
189-
for key in ("PYTORCH_CUDA_ALLOC_CONF", "PYTORCH_HIP_ALLOC_CONF",):
190-
if key in os.environ and "PYTORCH_ALLOC_CONF" not in os.environ:
191-
os.environ["PYTORCH_ALLOC_CONF"] = os.environ[key]
192-
delete_key(key)
208+
# Preserve explicit legacy allocator settings when user did not directly set PYTORCH_ALLOC_CONF.
209+
if not _HAS_ORIGINAL_PYTORCH_ALLOC_CONF:
210+
promoted = _ORIGINAL_PYTORCH_CUDA_ALLOC_CONF
211+
if promoted is None:
212+
promoted = _ORIGINAL_PYTORCH_HIP_ALLOC_CONF
213+
# Keep standby + ROCm protections when promoting legacy values.
214+
if os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "1" or IS_TORCH_ROCM_BUILD:
215+
promoted = clean_expandable_segments_value(promoted)
216+
if promoted is not None:
217+
os.environ["PYTORCH_ALLOC_CONF"] = promoted
218+
delete_key("PYTORCH_CUDA_ALLOC_CONF")
219+
delete_key("PYTORCH_HIP_ALLOC_CONF")
193220

194221
# Specify PYTORCH_CUDA_ALLOC_CONF or PYTORCH_HIP_ALLOC_CONF
195222
if IS_HIP_RUNTIME:
@@ -221,6 +248,8 @@ def filter(self, x): return not (self.text in x.getMessage())
221248
# CCE also fails in HIP / AMD
222249
os.environ["UNSLOTH_ENABLE_CCE"] = "0"
223250
del remove_expandable_segments, delete_key, IS_HIP_RUNTIME, IS_TORCH_ROCM_BUILD, major_torch, minor_torch, torch_version, torch_version_raw, importlib_version, find_spec
251+
del clean_expandable_segments_value
252+
del _ORIGINAL_PYTORCH_CUDA_ALLOC_CONF, _ORIGINAL_PYTORCH_HIP_ALLOC_CONF, _HAS_ORIGINAL_PYTORCH_ALLOC_CONF
224253

225254
if not ("UNSLOTH_IS_PRESENT" in os.environ):
226255
raise ImportError("Please install Unsloth via `pip install unsloth`!")

0 commit comments

Comments
 (0)