Commit ad7f5c6

fix attn import (#436)
1 parent b85f89c commit ad7f5c6

File tree

1 file changed (+3, -3 lines)

specforge/modeling/draft/llama3_eagle.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
@@ -26,9 +25,10 @@
 
 try:
     from flash_attn import flash_attn_func
-except:
+except ImportError:
     warnings.warn(
-        "flash_attn is not found, please install flash_attn if you want to use the flash attention backend"
+        "flash_attn is not found, falling back to flex_attention. "
+        "Please install flash_attn if you want to use the flash attention backend."
     )
     flash_attn_func = None
 
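For context, the diff applies the standard optional-dependency pattern: import flash_attn inside a try block, catch only ImportError (so real failures inside flash_attn still surface), and set flash_attn_func to None so callers can check which backend is available. The sketch below is a minimal, self-contained illustration of that pattern, assuming the same warning text as the commit; the pick_attention_backend helper is hypothetical and is not how llama3_eagle.py necessarily dispatches between backends.

import warnings

try:
    from flash_attn import flash_attn_func
except ImportError:  # only a missing package triggers the fallback
    warnings.warn(
        "flash_attn is not found, falling back to flex_attention. "
        "Please install flash_attn if you want to use the flash attention backend."
    )
    flash_attn_func = None


def pick_attention_backend() -> str:
    # Hypothetical helper: callers can guard on flash_attn_func being None
    # to decide whether the flash attention backend is usable.
    return "flash_attention" if flash_attn_func is not None else "flex_attention"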