Skip to content

Commit 9e83f31

Browse files
authored
Update utils.py
1 parent 2cb66db commit 9e83f31

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

colabdesign/shared/utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import sys, gc
66

77
def clear_mem():
8-
98
# clear vram (GPU)
109
backend = jax.lib.xla_bridge.get_backend()
1110
if hasattr(backend,'live_buffers'):
@@ -16,15 +15,16 @@ def clear_mem():
1615
# https://github.com/google/jax/issues/10828
1716
for module_name, module in sys.modules.items():
1817
if module_name.startswith("jax"):
19-
for obj_name in dir(module):
20-
obj = getattr(module, obj_name)
21-
if hasattr(obj, "cache_clear"):
22-
try:
23-
obj.cache_clear()
24-
except:
25-
pass
18+
if module_name not in ["jax.interpreters.partial_eval"]:
19+
for obj_name in dir(module):
20+
obj = getattr(module, obj_name)
21+
if hasattr(obj, "cache_clear"):
22+
try:
23+
obj.cache_clear()
24+
except:
25+
pass
2626
gc.collect()
27-
27+
2828
def update_dict(D, *args, **kwargs):
2929
'''robust function for updating dictionary'''
3030
def set_dict(d, x, override=False):

0 commit comments

Comments
 (0)