28 changes: 27 additions & 1 deletion ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
@@ -37,6 +37,12 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
def forward(self, x, use_cache=False):
b, num_tokens, d_in = x.shape

if use_cache:
# to prevent self.ptr_cur from becoming negative
assert num_tokens <= self.window_size, (
f"Input chunk size ({num_tokens}) exceeds KV cache window size ({self.window_size}). "
)

keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
values_new = self.W_value(x)
queries = self.W_query(x)
@@ -221,6 +227,7 @@ def __init__(self, cfg):

self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
self.kv_window_size = cfg.get("kv_window_size", cfg["context_length"])

def forward(self, in_idx, use_cache=False):
batch_size, seq_len = in_idx.shape
@@ -232,6 +239,12 @@ def forward(self, in_idx, use_cache=False):
# NEW

if use_cache:
context_length = self.pos_emb.num_embeddings
# to prevent generating more tokens than context_length,
# since exceeding it causes an out-of-bounds error when reading the position embedding
assert self.ptr_current_pos + seq_len <= context_length, (
f"Position embedding overflow: requested position {self.ptr_current_pos + seq_len} exceeds table size {context_length}"
)
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
self.ptr_current_pos += seq_len
else:
@@ -294,11 +307,24 @@ def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, u
model.eval()

ctx_len = context_size or model.pos_emb.num_embeddings
kv_window_size = model.kv_window_size

with torch.no_grad():
if use_cache:
model.reset_kv_cache()
logits = model(idx[:, -ctx_len:], use_cache=True)

input_tokens = idx[:, -ctx_len:]
input_tokens_length = input_tokens.size(1)

# prefill in chunks so prompts longer than kv_window_size are handled
for i in range(0, input_tokens_length, kv_window_size):
chunk = input_tokens[:, i:i+kv_window_size]
logits = model(chunk, use_cache=True)
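# only the logits from the last prefill chunk are used below;
# earlier chunks serve solely to populate the KV cache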

# can't generate more than ctx_len tokens in total,
# because the position embedding table only covers ctx_len positions
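# e.g., with ctx_len=10 and a 5-token prompt, at most 5 new tokens can be generated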
max_generable = ctx_len - input_tokens_length
max_new_tokens = min(max_new_tokens, max_generable)

for _ in range(max_new_tokens):
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
92 changes: 90 additions & 2 deletions ch04/03_kv-cache/tests.py
@@ -9,7 +9,8 @@

from gpt_with_kv_cache import GPTModel as GPTModelKV1
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
from gpt_with_kv_cache import generate_text_simple_cached
from gpt_with_kv_cache import generate_text_simple_cached as generate_text_simple_cachedKV1
from gpt_with_kv_cache_optimized import generate_text_simple_cached as generate_text_simple_cachedKV2


GPT_CONFIG_124M = {
@@ -20,6 +21,7 @@
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
"kv_window_size": 1024 # NEW: KV cache window size
}


@@ -80,8 +82,15 @@ def test_gpt_model_equivalence_cached(ModelClass):
max_new_tokens=30,
context_size=GPT_CONFIG_124M["context_length"]
)
elif ModelClass is GPTModelKV1:
token_ids = generate_text_simple_cachedKV1(
model=model,
idx=encoded_tensor,
max_new_tokens=30,
context_size=GPT_CONFIG_124M["context_length"]
)
else:
token_ids = generate_text_simple_cached(
token_ids = generate_text_simple_cachedKV2(
model=model,
idx=encoded_tensor,
max_new_tokens=30,
@@ -99,3 +108,82 @@ def test_gpt_model_equivalence_cached(ModelClass):
assert torch.equal(base_output, other_output), (
f"Mismatch between {base_name} and {other_name}"
)


def test_context_overflow_bug():
"""
Regression test for the ptr_current_pos overflow bug.

In the old implementation:
- context_length = 10 (positions 0-9 available)
- We try to generate 15 tokens total (5 input + 10 generated)
- At token 11 (position 10), it crashes trying to access pos_emb[10]
"""
GPT_CONFIG_SMALL = {
"vocab_size": 50257,
"context_length": 10, # Very small context
"emb_dim": 768,
"n_heads": 12,
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
"kv_window_size": 20 # Larger than context_length
}

torch.manual_seed(123)

model = GPTModelKV2(GPT_CONFIG_SMALL).to(device)
model.eval()

# 5 input tokens
input_tokens = torch.randint(0, 50257, (1, 5), device=device)

generate_text_simple_cachedKV2(
model=model,
idx=input_tokens,
max_new_tokens=10, # 5 + 10 = 15 > 10 context_length
context_size=GPT_CONFIG_SMALL["context_length"],
use_cache=True
)
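# With the clamp in generate_text_simple_cached, only 5 of the requested 10
# tokens are generated, so pos_emb is never indexed out of bounds.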


def test_prefill_chunking_basic():
"""
Test that prefill correctly chunks input when input_length > kv_window_size.

Setup:
- kv_window_size = 4
- input_length = 10
- Should process in 3 chunks: [0:4], [4:8], [8:10]
"""
config = {
"vocab_size": 50257,
"context_length": 20,
"emb_dim": 768,
"n_heads": 12,
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
"kv_window_size": 4 # Small window to force chunking
}

torch.manual_seed(123)
model = GPTModelKV2(config).to(device)
model.eval()

# 10 input tokens (> kv_window_size of 4)
input_tokens = torch.randint(0, 50257, (1, 10), device=device)

# Should successfully process all input in chunks
token_ids = generate_text_simple_cachedKV2(
model=model,
idx=input_tokens,
max_new_tokens=2,
use_cache=True
)

# Should have 10 input + 2 generated = 12 total
assert token_ids.shape[1] == 12, f"Expected 12 tokens, got {token_ids.shape[1]}"

# First 10 tokens should match input
assert torch.equal(token_ids[:, :10], input_tokens), "Input tokens should be preserved"