Skip to content

Commit cc23043

Browse files
jessegrossrick-github
authored andcommitted
kvcache: Don't shift empty batches
When we context shift, we delete half the context and apply RoPE with an offset to the other half. We used to RoPE across the entire context in a single pass with a zero offset for the deleted section. With the change to shifting in batches, we can skip any batches where all of the offsets would be zero. This typically reduces the number of operations by half.
1 parent 924e1eb commit cc23043

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

kvcache/causal.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,18 +646,31 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
646646
seqRange := c.cellRanges[seq]
647647

648648
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
649-
ctx := c.backend.NewContext()
650-
651649
size := min(seqRange.max-start+1, c.maxBatch)
652650
offsets := make([]int32, size)
651+
652+
var batchFirst, batchLast int
653+
654+
batchFirst = -1
653655
for i := range offsets {
654656
cell := c.cells[start+i]
655657

656658
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
657659
offsets[i] = offset
660+
if batchFirst < 0 {
661+
batchFirst = i
662+
}
663+
batchLast = i
658664
}
659665
}
660666

667+
if batchFirst < 0 {
668+
continue
669+
}
670+
671+
offsets = offsets[batchFirst : batchLast+1]
672+
673+
ctx := c.backend.NewContext()
661674
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
662675

663676
for i, key := range c.keys {
@@ -669,10 +682,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
669682
numKVHeads := key.Dim(1)
670683
rowSize := key.Stride(2)
671684

672-
key = key.View(ctx, rowSize*start,
685+
key = key.View(ctx, rowSize*(start+batchFirst),
673686
kHeadDim, key.Stride(1),
674687
numKVHeads, key.Stride(2),
675-
size,
688+
len(offsets),
676689
)
677690

678691
roped, err := c.shiftFn(ctx, i, key, kShift)

0 commit comments

Comments
 (0)