1 parent b85f89c commit ad7f5c6
specforge/modeling/draft/llama3_eagle.py
@@ -6,7 +6,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from flash_attn import flash_attn_func
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
@@ -26,9 +25,10 @@

 try:
     from flash_attn import flash_attn_func
-except:
+except ImportError:
     warnings.warn(
-        "flash_attn is not found, please install flash_attn if you want to use the flash attention backend"
+        "flash_attn is not found, falling back to flex_attention. "
+        "Please install flash_attn if you want to use the flash attention backend."
     )
     flash_attn_func = None
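The guarded import above follows a common optional-dependency pattern: try to import flash_attn, and if it is missing, warn once and set flash_attn_func to None so callers can dispatch to flex_attention instead. Below is a minimal sketch of that pattern; the attention helper and its tensor layout are illustrative assumptions, not the actual dispatch code in llama3_eagle.py.

# Minimal sketch of the optional-import fallback (illustrative only; not the
# repository's actual attention dispatch). Assumes q, k, v are given in
# (batch, num_heads, seq_len, head_dim) layout.
import warnings

import torch
from torch.nn.attention.flex_attention import flex_attention

try:
    from flash_attn import flash_attn_func
except ImportError:
    warnings.warn(
        "flash_attn is not found, falling back to flex_attention. "
        "Please install flash_attn if you want to use the flash attention backend."
    )
    flash_attn_func = None


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    if flash_attn_func is not None:
        # flash_attn_func expects (batch, seq_len, num_heads, head_dim) and
        # CUDA tensors in fp16/bf16, so transpose in and out of its layout.
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
        return out.transpose(1, 2)
    # flex_attention consumes (batch, num_heads, seq_len, head_dim) directly.
    return flex_attention(q, k, v)

This keeps the fallback decision at a single point: code that needs attention only checks whether flash_attn_func is None, rather than re-attempting the import.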