Skip to content

Commit 4c3aac5

Browse files
authored
Merging PR #12536
Merged via CLI script
1 parent bc1bdec commit 4c3aac5

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

vllm/attention/layer.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,13 @@ def forward(
156156
kv_cache: torch.Tensor,
157157
attn_metadata: AttentionMetadata,
158158
) -> torch.Tensor:
159-
if self.calculate_kv_scales and \
160-
attn_metadata.enable_kv_scales_calculation:
161-
self.calc_kv_scales(key, value)
159+
# NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
160+
# directly, use `self.kv_cache` and
161+
# `get_forward_context().attn_metadata` instead.
162+
if self.calculate_kv_scales:
163+
ctx_attn_metadata = get_forward_context().attn_metadata
164+
if ctx_attn_metadata.enable_kv_scales_calculation:
165+
self.calc_kv_scales(key, value)
162166
if self.use_output:
163167
output = torch.empty_like(query)
164168
hidden_size = query.size(-1)
@@ -172,15 +176,27 @@ def forward(
172176
if value is not None:
173177
value = value.view(-1, self.num_kv_heads, self.head_size)
174178
if self.use_direct_call:
175-
unified_attention_with_output(query, key, value, output,
176-
self.layer_name)
179+
forward_context: ForwardContext = get_forward_context()
180+
ctx_attn_metadata = forward_context.attn_metadata
181+
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
182+
self.impl.forward(self,
183+
query,
184+
key,
185+
value,
186+
self_kv_cache,
187+
ctx_attn_metadata,
188+
output=output)
177189
else:
178190
torch.ops.vllm.unified_attention_with_output(
179191
query, key, value, output, self.layer_name)
180192
return output.view(-1, hidden_size)
181193
else:
182194
if self.use_direct_call:
183-
return unified_attention(query, key, value, self.layer_name)
195+
forward_context = get_forward_context()
196+
ctx_attn_metadata = forward_context.attn_metadata
197+
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
198+
return self.impl.forward(self, query, key, value,
199+
self_kv_cache, ctx_attn_metadata)
184200
else:
185201
return torch.ops.vllm.unified_attention(
186202
query, key, value, self.layer_name)

0 commit comments

Comments (0)