|
80 | 80 | "We also have some installation instructions on our Github page." |
81 | 81 | ) |
82 | 82 |
|
# Keep original allocator settings to preserve explicit user config precedence.
# NOTE(review): these snapshots must be taken BEFORE any later import-time code
# mutates PYTORCH_CUDA_ALLOC_CONF / PYTORCH_HIP_ALLOC_CONF, so the torch>=2.9
# promotion logic further down sees the user's own values, not ours.
_ORIGINAL_PYTORCH_CUDA_ALLOC_CONF = os.environ.get("PYTORCH_CUDA_ALLOC_CONF")
_ORIGINAL_PYTORCH_HIP_ALLOC_CONF = os.environ.get("PYTORCH_HIP_ALLOC_CONF")
# Only the *presence* of PYTORCH_ALLOC_CONF matters: if the user set it directly
# we must never overwrite it with a promoted legacy value.
_HAS_ORIGINAL_PYTORCH_ALLOC_CONF = "PYTORCH_ALLOC_CONF" in os.environ
| 87 | + |
83 | 88 | # Reduce VRAM usage by reducing fragmentation |
84 | 89 | # And optimize pinning of memory |
85 | 90 | if os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0": |
@@ -132,6 +137,20 @@ def remove_expandable_segments(key): |
132 | 137 | delete_key(key) |
133 | 138 |
|
134 | 139 |
|
def clean_expandable_segments_value(value):
    """Strip any ``expandable_segments:*`` entry from an allocator-config string.

    ``value`` is a ``PYTORCH_*_ALLOC_CONF``-style comma-separated ``key:value``
    list, or ``None``.

    Returns the string with every ``expandable_segments:`` option (and any
    empty fragments from stray commas) removed, ``None`` when nothing remains,
    or ``value`` unchanged when there was nothing to strip.
    """
    # Fast path: nothing to do for missing values or values without the option.
    if value is None or "expandable_segments" not in value:
        return value
    kept = []
    for part in value.split(","):
        part = part.strip()
        # Drop empty fragments (e.g. trailing commas) and the filtered option.
        if not part or part.startswith("expandable_segments:"):
            continue
        kept.append(part)
    # Truthiness check instead of len(): an all-filtered value collapses to None.
    return ",".join(kept) if kept else None
| 152 | + |
| 153 | + |
135 | 154 | if (major_torch < 2): |
136 | 155 | raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\ |
137 | 156 | "We have some installation instructions on our Github page.") |
@@ -186,10 +205,18 @@ def filter(self, x): return not (self.text in x.getMessage()) |
186 | 205 |
|
# Torch 2.9 removed PYTORCH_HIP_ALLOC_CONF and PYTORCH_CUDA_ALLOC_CONF
if major_torch == 2 and minor_torch >= 9:
    # Preserve explicit legacy allocator settings when user did not directly set PYTORCH_ALLOC_CONF.
    if not _HAS_ORIGINAL_PYTORCH_ALLOC_CONF:
        # CUDA value takes precedence over the HIP value when both were set.
        promoted = _ORIGINAL_PYTORCH_CUDA_ALLOC_CONF
        if promoted is None:
            promoted = _ORIGINAL_PYTORCH_HIP_ALLOC_CONF
        # Keep standby + ROCm protections when promoting legacy values.
        # (expandable_segments is stripped in those modes, mirroring the
        # remove_expandable_segments handling applied earlier in this file.)
        if os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "1" or IS_TORCH_ROCM_BUILD:
            promoted = clean_expandable_segments_value(promoted)
        if promoted is not None:
            os.environ["PYTORCH_ALLOC_CONF"] = promoted
    # Legacy variables are ignored by torch>=2.9, so drop them regardless of
    # whether a promotion happened.
    # NOTE(review): the diff gutter loses exact indentation — confirm these two
    # delete_key calls sit outside the `if not _HAS_...` branch as written here.
    delete_key("PYTORCH_CUDA_ALLOC_CONF")
    delete_key("PYTORCH_HIP_ALLOC_CONF")
193 | 220 |
|
194 | 221 | # Specify PYTORCH_CUDA_ALLOC_CONF or PYTORCH_HIP_ALLOC_CONF |
195 | 222 | if IS_HIP_RUNTIME: |
@@ -221,6 +248,8 @@ def filter(self, x): return not (self.text in x.getMessage()) |
221 | 248 | # CCE also fails in HIP / AMD |
222 | 249 | os.environ["UNSLOTH_ENABLE_CCE"] = "0" |
# Remove import-time helpers and version probes so they do not leak into the
# public module namespace.
del remove_expandable_segments, delete_key, IS_HIP_RUNTIME, IS_TORCH_ROCM_BUILD, major_torch, minor_torch, torch_version, torch_version_raw, importlib_version, find_spec
del clean_expandable_segments_value
# Drop the allocator-env snapshots captured above; they are only needed for the
# torch>=2.9 promotion logic.
del _ORIGINAL_PYTORCH_CUDA_ALLOC_CONF, _ORIGINAL_PYTORCH_HIP_ALLOC_CONF, _HAS_ORIGINAL_PYTORCH_ALLOC_CONF
224 | 253 |
|
225 | 254 | if not ("UNSLOTH_IS_PRESENT" in os.environ): |
226 | 255 | raise ImportError("Please install Unsloth via `pip install unsloth`!") |
|
0 commit comments