File tree Expand file tree Collapse file tree 1 file changed +9
-9
lines changed
Expand file tree Collapse file tree 1 file changed +9
-9
lines changed Original file line number Diff line number Diff line change 55import sys , gc
66
77def clear_mem ():
8-
98 # clear vram (GPU)
109 backend = jax .lib .xla_bridge .get_backend ()
1110 if hasattr (backend ,'live_buffers' ):
@@ -16,15 +15,16 @@ def clear_mem():
1615 # https://github.com/google/jax/issues/10828
1716 for module_name , module in sys .modules .items ():
1817 if module_name .startswith ("jax" ):
19- for obj_name in dir (module ):
20- obj = getattr (module , obj_name )
21- if hasattr (obj , "cache_clear" ):
22- try :
23- obj .cache_clear ()
24- except :
25- pass
18+ if module_name not in ["jax.interpreters.partial_eval" ]:
19+ for obj_name in dir (module ):
20+ obj = getattr (module , obj_name )
21+ if hasattr (obj , "cache_clear" ):
22+ try :
23+ obj .cache_clear ()
24+ except :
25+ pass
2626 gc .collect ()
27-
27+
2828def update_dict (D , * args , ** kwargs ):
2929 '''robust function for updating dictionary'''
3030 def set_dict (d , x , override = False ):
You can’t perform that action at this time.
0 commit comments