
Commit ae0fd32

jessegross authored and rick-github committed
kvcache: Group shift operations into batches
Currently, when we need to do a shift on the cache, it is one RoPE operation on the entire size of the cache (per layer). In some cases, this can create a compute graph that is larger than the forward pass, since the forward pass works in batches. Because we don't consider shifting in our memory estimates, this can cause a crash if we run out of memory.

By limiting the size of the RoPE calls to batch-size chunks, we ensure that the shift will never exceed the size of the forward pass, since the forward pass will also contain a RoPE of the same size. This does not have a significant impact on performance, since RoPE is a math operation whose cost is mostly proportional to the size of its inputs.

In theory, defrag could have the same issue, since it also creates a compute graph outside of the forward pass. However, since it only performs copies, it does not require any working space.
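As an illustration of the batching scheme (a minimal, runnable sketch, not the actual kvcache API; the cell range and batch size below are made-up values), the loop walks a sequence's cell range in maxBatch-sized steps so that no single shift covers more cells than one forward-pass batch:

package main

import "fmt"

func main() {
	// Hypothetical values: a cached sequence occupying cells 0..1023,
	// shifted with a forward-pass batch size of 512.
	seqMin, seqMax := 0, 1023
	maxBatch := 512

	for start := seqMin; start <= seqMax; start += maxBatch {
		// The final chunk may be smaller than maxBatch.
		size := min(seqMax-start+1, maxBatch)

		// In the real code each chunk gets its own context and its own
		// RoPE (shift) call, bounding the graph to one batch of work.
		fmt.Printf("shift cells [%d, %d) (%d cells)\n", start, start+size, size)
	}
}

With these numbers the shift is issued as two chunks of 512 cells instead of a single 1024-cell operation.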
1 parent ec7445b · commit ae0fd32

File tree

1 file changed: +36 -29 lines changed


kvcache/causal.go

Lines changed: 36 additions & 29 deletions
@@ -25,6 +25,9 @@ type Causal struct {
 
 	opts CausalOptions
 
+	// maxBatch is the largest batch that we might receive
+	maxBatch int
+
 	// config controls mostly backend-specific optimizations
 	config *ml.CacheConfig
 
@@ -147,6 +150,7 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
 	c.DType = dtype
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
+	c.maxBatch = maxBatch
 }
 
 func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -639,48 +643,51 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 		return ErrNotSupported
 	}
 
-	ctx := c.backend.NewContext()
-	defer ctx.Close()
-
 	seqRange := c.cellRanges[seq]
-	size := seqRange.max - seqRange.min + 1
 
-	offsets := make([]int32, size)
-	for i := range offsets {
-		cell := c.cells[seqRange.min+i]
+	for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
+		ctx := c.backend.NewContext()
 
-		if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
-			offsets[i] = offset
+		size := min(seqRange.max-start+1, c.maxBatch)
+		offsets := make([]int32, size)
+		for i := range offsets {
+			cell := c.cells[start+i]
+
+			if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
+				offsets[i] = offset
+			}
 		}
-	}
 
-	kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
+		kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
 
-	for i, key := range c.keys {
-		if key == nil {
-			continue
-		}
+		for i, key := range c.keys {
+			if key == nil {
+				continue
+			}
 
-		kHeadDim := key.Dim(0)
-		numKVHeads := key.Dim(1)
-		rowSize := key.Stride(2)
+			kHeadDim := key.Dim(0)
+			numKVHeads := key.Dim(1)
+			rowSize := key.Stride(2)
 
-		key = key.View(ctx, rowSize*seqRange.min,
-			kHeadDim, key.Stride(1),
-			numKVHeads, key.Stride(2),
-			size,
-		)
+			key = key.View(ctx, rowSize*start,
+				kHeadDim, key.Stride(1),
+				numKVHeads, key.Stride(2),
+				size,
+			)
 
-		roped, err := c.shiftFn(ctx, i, key, kShift)
-		if err != nil {
-			return err
+			roped, err := c.shiftFn(ctx, i, key, kShift)
+			if err != nil {
+				ctx.Close()
+				return err
+			}
+
+			ctx.Forward(roped.Copy(ctx, key))
 		}
 
-		ctx.Forward(roped.Copy(ctx, key))
+		ctx.Compute()
+		ctx.Close()
 	}
 
-	ctx.Compute()
-
 	return nil
 }
 
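One note on the new loop above: size := min(seqRange.max-start+1, c.maxBatch) caps the final chunk, so a 1100-cell range with a batch size of 512 would be shifted as chunks of 512, 512, and 76 cells. And because the context is now created inside each iteration rather than once with defer ctx.Close(), the error path closes it explicitly before returning, while the normal path closes it right after ctx.Compute().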

0 commit comments
