Skip to content

Commit 1e462f1

Browse files
committed
Force loading tensors on default stream
1 parent 0d5c0bc commit 1e462f1

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

exllamav2/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44
import math
5-
from exllamav2.fasttensors import STFile
5+
from exllamav2.fasttensors import STFile, cleanup_stfiles
66
from exllamav2.architecture import ExLlamaV2ArchParams
77
import os, glob, json
88
from typing import Any, Dict, List, TypeVar, Union, cast
@@ -370,7 +370,7 @@ def prepare(self, no_tensors: bool = False):
370370
if not match:
371371
raise ValueError(f" ## Could not find {prefix}.* in model")
372372

373-
x = 0
373+
cleanup_stfiles()
374374

375375

376376
def arch_compat_overrides(self, quiet: bool = False, warn_only = False):

exllamav2/fasttensors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ def get_tensor(self,
191191

192192
torch.cuda.synchronize()
193193

194+
if device != "cpu":
195+
torch.cuda.set_stream(torch.cuda.default_stream(device))
196+
194197
if self.tensor_remap and (not_fast or not self.fast):
195198
key = self.tensor_remap[key]
196199

@@ -213,8 +216,6 @@ def get_tensor(self,
213216
size = end - beg
214217
numel = size // esize
215218
shape = h["shape"]
216-
if device != "cpu":
217-
torch.cuda.set_stream(torch.cuda.default_stream(device))
218219
tensor = torch.zeros(shape, dtype = dtype, device = device)
219220
assert tensor.is_contiguous, "Non-contiguous tensor"
220221
ext_c.safetensors_read_fb(self.handle_fb, beg + self.header_size, size, tensor)

exllamav2/linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,8 @@ def tp_split(self, broadcast_type: int, dim = None):
597597
)
598598
)
599599

600+
torch.cuda.synchronize()
601+
600602
ext_c.free_q_matrix(self.q_handle)
601603
self.q_handle = new_q_handle
602604
self.q_tensors = new_q_tensors

0 commit comments

Comments (0)