|
20 | 20 | import os |
21 | 21 | import re |
22 | 22 | import shutil |
23 | | -import signal |
24 | 23 | import sys |
25 | 24 | import threading |
26 | 25 | from pathlib import Path |
|
34 | 33 |
|
35 | 34 | from .. import __version__ |
36 | 35 | from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging |
| 36 | +from .constants import DIFFUSERS_DISABLE_REMOTE_CODE |
37 | 37 |
|
38 | 38 |
|
39 | 39 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -159,52 +159,25 @@ def check_imports(filename): |
159 | 159 | return get_relative_imports(filename) |
160 | 160 |
|
161 | 161 |
|
162 | | -def _raise_timeout_error(signum, frame): |
163 | | - raise ValueError( |
164 | | - "Loading this model requires you to execute custom code contained in the model repository on your local " |
165 | | - "machine. Please set the option `trust_remote_code=True` to permit loading of this model." |
166 | | - ) |
167 | | - |
168 | | - |
169 | 162 | def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): |
170 | | - if trust_remote_code is None: |
171 | | - if has_remote_code and TIME_OUT_REMOTE_CODE > 0: |
172 | | - prev_sig_handler = None |
173 | | - try: |
174 | | - prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) |
175 | | - signal.alarm(TIME_OUT_REMOTE_CODE) |
176 | | - while trust_remote_code is None: |
177 | | - answer = input( |
178 | | - f"The repository for {model_name} contains custom code which must be executed to correctly " |
179 | | - f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" |
180 | | - f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n" |
181 | | - f"Do you wish to run the custom code? [y/N] " |
182 | | - ) |
183 | | - if answer.lower() in ["yes", "y", "1"]: |
184 | | - trust_remote_code = True |
185 | | - elif answer.lower() in ["no", "n", "0", ""]: |
186 | | - trust_remote_code = False |
187 | | - signal.alarm(0) |
188 | | - except Exception: |
189 | | - # OS which does not support signal.SIGALRM |
190 | | - raise ValueError( |
191 | | - f"The repository for {model_name} contains custom code which must be executed to correctly " |
192 | | - f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" |
193 | | - f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." |
194 | | - ) |
195 | | - finally: |
196 | | - if prev_sig_handler is not None: |
197 | | - signal.signal(signal.SIGALRM, prev_sig_handler) |
198 | | - signal.alarm(0) |
199 | | - elif has_remote_code: |
200 | | - # For the CI which puts the timeout at 0 |
201 | | - _raise_timeout_error(None, None) |
| 163 | + trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE |
| 164 | + if DIFFUSERS_DISABLE_REMOTE_CODE: |
| 165 | + logger.warning( |
| 166 | + "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`." |
| 167 | + ) |
202 | 168 |
|
203 | 169 | if has_remote_code and not trust_remote_code: |
204 | | - raise ValueError( |
205 | | - f"Loading {model_name} requires you to execute the configuration file in that" |
206 | | - " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" |
207 | | - " set the option `trust_remote_code=True` to remove this error." |
| 170 | + error_msg = f"The repository for {model_name} contains custom code. " |
| 171 | + error_msg += ( |
| 172 | + "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable." |
| 173 | + if DIFFUSERS_DISABLE_REMOTE_CODE |
| 174 | + else "Pass `trust_remote_code=True` to allow loading remote code modules." |
| 175 | + ) |
| 176 | + raise ValueError(error_msg) |
| 177 | + |
| 178 | + elif has_remote_code and trust_remote_code: |
| 179 | + logger.warning( |
| 180 | + f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository" |
208 | 181 | ) |
209 | 182 |
|
210 | 183 | return trust_remote_code |
|
0 commit comments