Skip to content

Commit c90c4d3

Browse files
authored
bugfix for latest jax
1 parent 9494856 commit c90c4d3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

colabdesign/shared/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ def clear_mem():
1919
for obj_name in dir(module):
2020
obj = getattr(module, obj_name)
2121
if hasattr(obj, "cache_clear"):
22-
obj.cache_clear()
22+
try:
23+
obj.cache_clear()
24+
except:
25+
pass
2326
gc.collect()
2427

2528
def update_dict(D, *args, **kwargs):
@@ -105,4 +108,4 @@ def softmax(x, axis=-1):
105108
return x / x.sum(axis,keepdims=True)
106109

107110
def categorical(p):
108-
return (p.cumsum(-1) >= np.random.uniform(size=p.shape[:-1])[..., None]).argmax(-1)
111+
return (p.cumsum(-1) >= np.random.uniform(size=p.shape[:-1])[..., None]).argmax(-1)

0 commit comments

Comments
 (0)