
Commit 4c0cc77

Author: yingjieluan (committed)
Sliding window KV Cache bug fix
1. Fix a bug caused by the KV cache and the GPT model's ptr pointer not being reset when window_size > context_length
2. Fix a bug caused by the KV cache and the GPT model's ptr pointer not being reset
3. Fix the KV cache import issue for gpt_with_kv_cache_optimized
1 parent a11965f commit 4c0cc77
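For orientation, the behavior these fixes enable can be exercised end to end with the snippet below. This is a minimal sketch, not part of the commit: it assumes the ch04/03_kv-cache files are importable, and the small emb_dim/n_heads/n_layers values are toy settings chosen only to keep the model cheap to construct. Only the config keys, GPTModel, and generate_text_simple_cached come from the repository.

# Sketch: chunked prefill (prompt longer than kv_window_size) plus clamped generation.
import torch
from gpt_with_kv_cache_optimized import GPTModel, generate_text_simple_cached

cfg = {
    "vocab_size": 50257,
    "context_length": 32,   # position-embedding size
    "emb_dim": 48,          # toy size (must be divisible by n_heads)
    "n_heads": 4,           # toy size
    "n_layers": 2,          # toy size
    "drop_rate": 0.0,
    "qkv_bias": False,
    "kv_window_size": 8,    # smaller than the 20-token prompt below, forcing chunked prefill
}

torch.manual_seed(123)
model = GPTModel(cfg)
model.eval()

prompt = torch.randint(0, cfg["vocab_size"], (1, 20))

# Before this commit, a prompt larger than kv_window_size (or a request for more
# tokens than the position embedding can address) would break; now the prompt is
# prefilled in kv_window_size chunks and max_new_tokens is clamped internally.
out = generate_text_simple_cached(model=model, idx=prompt, max_new_tokens=50, use_cache=True)
print(out.shape)  # torch.Size([1, 32]): generation stops at the context length

With kv_window_size=8 the 20-token prompt is prefilled in chunks of 8, 8, and 4, and max_new_tokens=50 is clamped to the 12 positions that remain before the position embedding runs out.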

File tree: 2 files changed (+117, -3 lines)


ch04/03_kv-cache/gpt_with_kv_cache_optimized.py

Lines changed: 27 additions & 1 deletion
@@ -37,6 +37,12 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
     def forward(self, x, use_cache=False):
         b, num_tokens, d_in = x.shape

+        if use_cache:
+            # Prevent self.ptr_cur from becoming negative
+            assert num_tokens <= self.window_size, (
+                f"Input chunk size ({num_tokens}) exceeds KV cache window size ({self.window_size})."
+            )
+
         keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
         values_new = self.W_value(x)
         queries = self.W_query(x)
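The assert above protects the sliding-window pointer arithmetic inside the cache. The toy buffer below is illustrative only (the cache tensor, the standalone append helper, and the shapes are invented for this sketch and are not the code in gpt_with_kv_cache_optimized.py); it shows why a chunk longer than the window would drive the write pointer negative.

# Toy sliding-window KV buffer with a write pointer (illustrative only).
import torch

window_size = 4
cache = torch.zeros(window_size, 8)  # (window positions, head_dim), toy shapes
ptr_cur = 0

def append(chunk):
    """Write `chunk` (num_tokens x head_dim) at the current pointer, sliding if needed."""
    global ptr_cur
    n = chunk.size(0)
    if ptr_cur + n > window_size:
        keep = window_size - n  # entries to retain; negative whenever n > window_size
        assert keep >= 0, "chunk larger than the KV cache window"
        cache[:keep] = cache[ptr_cur - keep:ptr_cur].clone()  # keep the most recent entries
        ptr_cur = keep
    cache[ptr_cur:ptr_cur + n] = chunk
    ptr_cur += n

append(torch.randn(3, 8))    # fits: 3 <= window_size
append(torch.randn(2, 8))    # overflows the window, so the buffer slides first
# append(torch.randn(5, 8))  # 5 > window_size: keep would be -1, the assert fires

With the assert in place, callers (such as the chunked prefill added further down) must split long inputs before feeding them to the attention layer.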
@@ -221,6 +227,7 @@ def __init__(self, cfg):

         self.final_norm = LayerNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+        self.kv_window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]

     def forward(self, in_idx, use_cache=False):
         batch_size, seq_len = in_idx.shape
@@ -232,6 +239,12 @@ def forward(self, in_idx, use_cache=False):
         # NEW

         if use_cache:
+            context_length = self.pos_emb.num_embeddings
+            # Prevent generating more tokens than context_length, since going past it
+            # causes an out-of-bounds read from the position embedding
+            assert self.ptr_current_pos + seq_len <= context_length, (
+                f"Position embedding overflow: requested position {self.ptr_current_pos + seq_len} exceeds the embedding size of {context_length}"
+            )
             pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
             self.ptr_current_pos += seq_len
         else:
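The new assert exists because torch.nn.Embedding raises an IndexError for indices at or beyond num_embeddings. A quick standalone check of that failure mode (not from the commit):

# Standalone check of the failure the assert prevents: reading a position id
# that is >= num_embeddings from an nn.Embedding raises an IndexError.
import torch
import torch.nn as nn

context_length = 10
pos_emb = nn.Embedding(context_length, 16)

print(pos_emb(torch.arange(0, context_length)).shape)  # positions 0-9: OK, torch.Size([10, 16])

try:
    pos_emb(torch.arange(8, 12))  # positions 10 and 11 do not exist
except IndexError as err:
    print("out of range:", err)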
@@ -294,11 +307,24 @@ def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, u
     model.eval()

     ctx_len = context_size or model.pos_emb.num_embeddings
+    kv_window_size = model.kv_window_size

     with torch.no_grad():
         if use_cache:
             model.reset_kv_cache()
-            logits = model(idx[:, -ctx_len:], use_cache=True)
+
+            input_tokens = idx[:, -ctx_len:]
+            input_tokens_length = input_tokens.size(1)
+
+            # Prefill in chunks to handle input_tokens_length > kv_window_size
+            for i in range(0, input_tokens_length, kv_window_size):
+                chunk = input_tokens[:, i:i+kv_window_size]
+                logits = model(chunk, use_cache=True)
+
+            # Can't generate more than ctx_len tokens in total,
+            # due to the fixed size of the position embedding
+            max_generable = ctx_len - input_tokens_length
+            max_new_tokens = min(max_new_tokens, max_generable)

             for _ in range(max_new_tokens):
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
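For reference, the chunk boundaries that the prefill loop walks come straight from range with a step of kv_window_size; slicing takes care of the final, shorter chunk. A plain-Python illustration (the values match the new test added below):

# Chunk boundaries produced by the prefill loop for input_tokens_length=10, kv_window_size=4.
input_tokens_length = 10
kv_window_size = 4

chunks = [(i, min(i + kv_window_size, input_tokens_length))
          for i in range(0, input_tokens_length, kv_window_size)]
print(chunks)  # [(0, 4), (4, 8), (8, 10)] -- the last chunk is simply shorter

Only the logits from the last chunk are kept, which is all the greedy decoding loop needs: the keys and values of the earlier chunks are already stored in the cache.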

ch04/03_kv-cache/tests.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@

 from gpt_with_kv_cache import GPTModel as GPTModelKV1
 from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
-from gpt_with_kv_cache import generate_text_simple_cached
+from gpt_with_kv_cache import generate_text_simple_cached as generate_text_simple_cachedKV1
+from gpt_with_kv_cache_optimized import generate_text_simple_cached as generate_text_simple_cachedKV2


 GPT_CONFIG_124M = {
@@ -20,6 +21,7 @@
     "n_layers": 12,
     "drop_rate": 0.1,
     "qkv_bias": False,
+    "kv_window_size": 1024  # NEW: KV cache window size
 }


@@ -80,8 +82,15 @@ def test_gpt_model_equivalence_cached(ModelClass):
             max_new_tokens=30,
             context_size=GPT_CONFIG_124M["context_length"]
         )
+    elif ModelClass is GPTModelKV1:
+        token_ids = generate_text_simple_cachedKV1(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
     else:
-        token_ids = generate_text_simple_cached(
+        token_ids = generate_text_simple_cachedKV2(
             model=model,
             idx=encoded_tensor,
             max_new_tokens=30,
@@ -99,3 +108,82 @@ def test_gpt_model_equivalence_cached(ModelClass):
         assert torch.equal(base_output, other_output), (
             f"Mismatch between {base_name} and {other_name}"
         )
+
+
+def test_context_overflow_bug():
+    """
+    Test that demonstrates the ptr_current_pos overflow bug.
+
+    In the old implementation:
+    - context_length = 10 (positions 0-9 available)
+    - We try to generate 15 tokens total (5 input + 10 generated)
+    - At token 11 (position 10), it crashes trying to access pos_emb[10]
+    """
+    GPT_CONFIG_SMALL = {
+        "vocab_size": 50257,
+        "context_length": 10,  # Very small context
+        "emb_dim": 768,
+        "n_heads": 12,
+        "n_layers": 12,
+        "drop_rate": 0.1,
+        "qkv_bias": False,
+        "kv_window_size": 20  # Larger than context_length
+    }
+
+    torch.manual_seed(123)
+
+    model = GPTModelKV2(GPT_CONFIG_SMALL).to(device)
+    model.eval()
+
+    # 5 input tokens
+    input_tokens = torch.randint(0, 50257, (1, 5), device=device)
+
+    # After the fix, generation is clamped to the context length and this no longer raises
+    generate_text_simple_cachedKV2(
+        model=model,
+        idx=input_tokens,
+        max_new_tokens=10,  # 5 + 10 = 15 > 10 context_length
+        context_size=GPT_CONFIG_SMALL["context_length"],
+        use_cache=True
+    )
+
+
+def test_prefill_chunking_basic():
+    """
+    Test that prefill correctly chunks the input when input_length > kv_window_size.
+
+    Setup:
+    - kv_window_size = 4
+    - input_length = 10
+    - Should process in 3 chunks: [0:4], [4:8], [8:10]
+    """
+    config = {
+        "vocab_size": 50257,
+        "context_length": 20,
+        "emb_dim": 768,
+        "n_heads": 12,
+        "n_layers": 12,
+        "drop_rate": 0.1,
+        "qkv_bias": False,
+        "kv_window_size": 4  # Small window to force chunking
+    }
+
+    torch.manual_seed(123)
+    model = GPTModelKV2(config).to(device)
+    model.eval()
+
+    # 10 input tokens (> kv_window_size of 4)
+    input_tokens = torch.randint(0, 50257, (1, 10), device=device)
+
+    # Should successfully process all input in chunks
+    token_ids = generate_text_simple_cachedKV2(
+        model=model,
+        idx=input_tokens,
+        max_new_tokens=2,
+        use_cache=True
+    )
+
+    # Should have 10 input + 2 generated = 12 total
+    assert token_ids.shape[1] == 12, f"Expected 12 tokens, got {token_ids.shape[1]}"
+
+    # First 10 tokens should match input
+    assert torch.equal(token_ids[:, :10], input_tokens), "Input tokens should be preserved"
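The two new regression tests can be run on their own; a usage sketch, assuming pytest is installed and the working directory is ch04/03_kv-cache:

# Select only the tests added in this commit by name.
import pytest

pytest.main(["-q", "tests.py", "-k", "context_overflow_bug or prefill_chunking_basic"])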
