Commit affdc0d

Ensure streams are always set during the forward pass for the active thread
1 parent 5c455c1

2 files changed: +12, -8 lines

exllamav2/device.py
Lines changed: 7 additions & 0 deletions

@@ -17,6 +17,13 @@ def _torch_device(idx):
 global_streams = {}
 
 
+def set_device_streams():
+    global global_streams
+    for (k, v) in global_streams.items():
+        with torch.cuda.device(torch.device(k)):
+            torch.cuda.set_stream(v)
+
+
 class ExLlamaV2DeviceContext:
 
     model: ExLlamaV2
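
For reference, PyTorch tracks the current CUDA stream per thread and per device: a stream selected with torch.cuda.set_stream() in one thread is not current in any other thread. The new set_device_streams() helper therefore re-applies the recorded stream for every device from whichever thread is running. A minimal sketch of that thread-local behavior (assumes at least one CUDA device; illustration only, not part of this commit):

import threading
import torch

s = torch.cuda.Stream(device = 0)
torch.cuda.set_stream(s)
# The main thread now runs on the non-default stream s.
assert torch.cuda.current_stream(0) == s

def check():
    # A fresh thread starts on the default stream again; s is not
    # current here, so stream selection must be repeated per thread.
    assert torch.cuda.current_stream(0) == torch.cuda.default_stream(0)

t = threading.Thread(target = check)
t.start()
t.join()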

exllamav2/model.py
Lines changed: 5 additions & 8 deletions

@@ -47,7 +47,7 @@
 from exllamav2.pos_embedding import ExLlamaV2PosEmbedding
 from exllamav2.compat import safe_move_tensor
 from exllamav2.fasttensors import cleanup_stfiles
-from exllamav2.device import ExLlamaV2DeviceContext
+from exllamav2.device import ExLlamaV2DeviceContext, set_device_streams
 from exllamav2.tensor_p import TPContext, BROADCAST_VC
 import gc
 import threading

@@ -923,6 +923,10 @@ def forward_chunk(self,
             seq_len <= self.config.max_output_len, \
             "seq_len exceeds max_output_len"
 
+        # Ensure streams are always set in the current thread
+
+        set_device_streams()
+
         # Output
 
         r = {}

@@ -944,10 +948,6 @@ def forward_chunk(self,
         cache.current_seq_len = past_len
 
         device = self.modules[0].device_idx
-        if device is not None and device >= 0:
-            context = self.get_device_context(device)
-            if context:
-                torch.cuda.set_stream(context.stream)
 
         for idx, module in enumerate(self.modules):
 

@@ -969,9 +969,6 @@
             n_device = module.device_idx
             if n_device is not None and n_device != device and n_device >= 0:
                 x = safe_move_tensor(x, n_device, non_blocking = True)
-                device = n_device
-                context = self.get_device_context(device)
-                torch.cuda.set_stream(context.stream)
 
             x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras, **kwargs)

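Taken together, the two changes replace the per-module torch.cuda.set_stream() calls in forward_chunk() with a single set_device_streams() call on entry, so the correct streams are active no matter which thread invokes the forward pass. A hypothetical driver (names like generate_async and run are illustrative, not exllamav2 API) sketching the scenario the commit message targets:

import threading

def generate_async(model, input_ids, cache):
    def run():
        # This worker thread never called torch.cuda.set_stream();
        # forward_chunk() now calls set_device_streams() on entry, so
        # the model's recorded streams become current here as well.
        model.forward(input_ids, cache = cache)
    t = threading.Thread(target = run)
    t.start()
    return t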