This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 711eb41 (parent 0b0a436)

Fix cache_lane in forward

1 file changed (+4, −2)

torchchat/model.py

Lines changed: 4 additions & 2 deletions
@@ -653,7 +653,7 @@ def distribute(self, device_mesh: DeviceMesh):
             ColwiseParallel(output_layouts=Replicate()),
         )
 
-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
@@ -686,7 +686,9 @@ def distribute(self, device_mesh: DeviceMesh):
     def forward(
         self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
     ) -> Tensor:
-        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+        h = x + self.attention(
+            self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
+        )
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
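For context, the sketch below illustrates the pattern this fix relies on: a lane-indexed KV cache, where the lane chosen in the model's top-level forward has to be threaded through each block into the attention call so the key/value update lands in the intended lane. This is an illustrative assumption rather than torchchat's actual cache code; the class name MultiLaneKVCache and its update signature are hypothetical.

import torch
from torch import Tensor


class MultiLaneKVCache(torch.nn.Module):
    # Hypothetical lane-indexed KV cache: one key/value buffer slice per lane.
    def __init__(self, lanes: int, batch: int, heads: int, seq_len: int, head_dim: int):
        super().__init__()
        shape = (lanes, batch, heads, seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos: Tensor, k_new: Tensor, v_new: Tensor, cache_lane: int = 0):
        # Write the new keys/values into the selected lane only. If the caller
        # defaults to the wrong lane, or drops cache_lane on the way down, the
        # update goes into a buffer that decoding never reads from.
        self.k_cache[cache_lane, :, :, input_pos] = k_new
        self.v_cache[cache_lane, :, :, input_pos] = v_new
        return self.k_cache[cache_lane], self.v_cache[cache_lane]

With a cache shaped like this, the added cache_lane=cache_lane keyword in the block's forward is what lets the attention module route its cache update to the lane selected at the top level, and the top-level default of 0 matches the default already used by the block-level forward.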