File tree Expand file tree Collapse file tree 1 file changed +11
-1
lines changed Expand file tree Collapse file tree 1 file changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -1164,8 +1164,18 @@ def _verify_quantization(self) -> None:
1164
1164
"non-quantized models." , self .quantization )
1165
1165
1166
1166
def _verify_cuda_graph (self ) -> None :
1167
+ # The `max_seq_len_to_capture` was incorrectly
1168
+ # based on the decoder's input length (448)
1169
+ # but not the encoder's larger input length (1500).
1170
+ # This change ensures the CUDA Graph captures the correct,
1171
+ # larger sequence length, allowing it to work as intended.
1172
+ effective_max_seq_len = self .max_model_len
1173
+ if self .is_encoder_decoder :
1174
+ effective_max_seq_len = max (
1175
+ effective_max_seq_len ,
1176
+ getattr (self .hf_config , "max_source_positions" , 0 ))
1167
1177
self .max_seq_len_to_capture = min (self .max_seq_len_to_capture ,
1168
- self . max_model_len )
1178
+ effective_max_seq_len )
1169
1179
# CUDAGraph capture not supported for enc-dec models and mllama on ROCm
1170
1180
ROCM_UNSUPPORTED_MODELS = ['mllama' ]
1171
1181
unsupported_rocm = (self .hf_config .model_type
You can’t perform that action at this time.
0 commit comments