Skip to content

Commit 8d3d4c2

Browse files
committed
Ensure logit padding happens on default stream
1 parent d9f0ecc commit 8d3d4c2

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

exllamav2/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,9 @@ def forward_chunk(self,
989989
if self.tp_context:
990990
self.tp_context.wait_streams()
991991

992+
if x is not None and x.is_cuda:
993+
torch.cuda.set_stream(torch.cuda.default_stream(x.device))
994+
992995
# Apply logit scale
993996

994997
# if x is not None and self.config.logit_scale != 1:

0 commit comments

Comments
 (0)