@@ -156,9 +156,13 @@ def forward(
156
156
kv_cache : torch .Tensor ,
157
157
attn_metadata : AttentionMetadata ,
158
158
) -> torch .Tensor :
159
- if self .calculate_kv_scales and \
160
- attn_metadata .enable_kv_scales_calculation :
161
- self .calc_kv_scales (key , value )
159
+ # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
160
+ # directly, use `self.kv_cache` and
161
+ # `get_forward_context().attn_metadata` instead.
162
+ if self .calculate_kv_scales :
163
+ ctx_attn_metadata = get_forward_context ().attn_metadata
164
+ if ctx_attn_metadata .enable_kv_scales_calculation :
165
+ self .calc_kv_scales (key , value )
162
166
if self .use_output :
163
167
output = torch .empty_like (query )
164
168
hidden_size = query .size (- 1 )
@@ -172,15 +176,27 @@ def forward(
172
176
if value is not None :
173
177
value = value .view (- 1 , self .num_kv_heads , self .head_size )
174
178
if self .use_direct_call :
175
- unified_attention_with_output (query , key , value , output ,
176
- self .layer_name )
179
+ forward_context : ForwardContext = get_forward_context ()
180
+ ctx_attn_metadata = forward_context .attn_metadata
181
+ self_kv_cache = self .kv_cache [forward_context .virtual_engine ]
182
+ self .impl .forward (self ,
183
+ query ,
184
+ key ,
185
+ value ,
186
+ self_kv_cache ,
187
+ ctx_attn_metadata ,
188
+ output = output )
177
189
else :
178
190
torch .ops .vllm .unified_attention_with_output (
179
191
query , key , value , output , self .layer_name )
180
192
return output .view (- 1 , hidden_size )
181
193
else :
182
194
if self .use_direct_call :
183
- return unified_attention (query , key , value , self .layer_name )
195
+ forward_context = get_forward_context ()
196
+ ctx_attn_metadata = forward_context .attn_metadata
197
+ self_kv_cache = self .kv_cache [forward_context .virtual_engine ]
198
+ return self .impl .forward (self , query , key , value ,
199
+ self_kv_cache , ctx_attn_metadata )
184
200
else :
185
201
return torch .ops .vllm .unified_attention (
186
202
query , key , value , self .layer_name )
0 commit comments