@@ -47,7 +47,7 @@
 from exllamav2.pos_embedding import ExLlamaV2PosEmbedding
 from exllamav2.compat import safe_move_tensor
 from exllamav2.fasttensors import cleanup_stfiles
-from exllamav2.device import ExLlamaV2DeviceContext
+from exllamav2.device import ExLlamaV2DeviceContext, set_device_streams
 from exllamav2.tensor_p import TPContext, BROADCAST_VC
 import gc
 import threading
@@ -923,6 +923,10 @@ def forward_chunk(self,
             seq_len <= self.config.max_output_len, \
             "seq_len exceeds max_output_len"
 
+        # Ensure streams are always set in the current thread
+
+        set_device_streams()
+
         # Output
 
         r = {}
@@ -944,10 +948,6 @@ def forward_chunk(self,
         cache.current_seq_len = past_len
 
         device = self.modules[0].device_idx
-        if device is not None and device >= 0:
-            context = self.get_device_context(device)
-            if context:
-                torch.cuda.set_stream(context.stream)
 
         for idx, module in enumerate(self.modules):
 
@@ -969,9 +969,6 @@ def forward_chunk(self,
             n_device = module.device_idx
             if n_device is not None and n_device != device and n_device >= 0:
                 x = safe_move_tensor(x, n_device, non_blocking = True)
-                device = n_device
-                context = self.get_device_context(device)
-                torch.cuda.set_stream(context.stream)
 
             x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras, **kwargs)
 
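The removed hunks and the new call explain the change: PyTorch tracks the current CUDA stream per device and per thread, so a stream selected once in the loading thread is not seen by a worker thread that later enters forward_chunk. Below is a minimal sketch of what a helper like set_device_streams could look like, assuming a module-level registry of device contexts and the .stream attribute the removed code already references; the actual exllamav2.device implementation may differ.

import torch

# Hypothetical registry of per-device contexts; the real exllamav2.device
# module tracks its ExLlamaV2DeviceContext objects differently. Each context
# is assumed to hold a torch.cuda.Stream in .stream, as the removed
# torch.cuda.set_stream(context.stream) calls imply.
_device_contexts = {}

def set_device_streams():
    # PyTorch's "current stream" is per device and per thread, so a stream
    # selected while loading the model is not visible to a worker thread
    # that later calls forward_chunk. Re-applying every device's stream here
    # replaces the per-transition set_stream calls removed in the diff above.
    for context in _device_contexts.values():
        if context is not None and context.stream is not None:
            # set_stream makes this the calling thread's current stream on
            # the stream's own device
            torch.cuda.set_stream(context.stream)

With a thread-level setup like this, the module loop no longer needs to track device transitions: every device's stream is already current for the calling thread, and safe_move_tensor alone handles the cross-device hop.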