Commit d3fe9f2

Unmap tensors on CPU to reduce temp VRAM overhead while loading
1 parent 7e15947 commit d3fe9f2

6 files changed: +95 −14 lines
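The change, in short: when a layer's output columns have to be permuted at load time (an `output_map` is set), the quantized tensors are now read into system RAM, remapped in place there by two new C++ helpers, and only then moved to the target device. Previously the remap was done with fancy indexing on tensors that were already on the GPU, which briefly held both the original and the permuted copy in VRAM. A minimal Python sketch of the new flow (illustrative only; `get_tensor` is a hypothetical loader callback, and plain indexing stands in for the in-place extension calls):

import torch

def load_remapped(get_tensor, output_map: torch.Tensor, device: str) -> torch.Tensor:
    # Read the tensor into system RAM so no scratch copy is ever allocated in VRAM.
    w = get_tensor(device = "cpu")
    # Reorder the columns on the CPU. The commit does this in place via the new
    # tensor_remap / tensor_remap_4bit extension functions; indexing is used here
    # only to illustrate the result.
    w = w[:, output_map.long()]
    # Move the final, already-remapped tensor to the GPU exactly once.
    return w.to(device)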

exllamav2/exllamav2_ext/cpp/safetensors.cpp

Lines changed: 62 additions & 1 deletion
@@ -453,4 +453,65 @@ void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target)
             remaining -= chunk;
         }
     }
-}
+}
+
+void tensor_remap
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+)
+{
+    TORCH_CHECK_SHAPES(tensor, 1, index, 0, 1);
+    TORCH_CHECK_DTYPE(tensor, kInt);
+    TORCH_CHECK_DTYPE(index, kInt);
+
+    int rows = tensor.size(0);
+    int cols = tensor.size(1);
+    uint32_t* temp = (uint32_t*) calloc(cols, sizeof(int));
+    uint32_t* a = (uint32_t*) tensor.data_ptr();
+    uint32_t* idx = (uint32_t*) index.data_ptr();
+
+    for (int r = 0; r < rows; ++r)
+    {
+        memcpy(temp, a, sizeof(uint32_t) * cols);
+        for (int c = 0; c < cols; ++c)
+        {
+            *a++ = temp[idx[c]];
+        }
+    }
+    free(temp);
+}
+
+void tensor_remap_4bit
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+)
+{
+    TORCH_CHECK_SHAPES(index, 0, tensor, 1, 8);
+    TORCH_CHECK_DTYPE(tensor, kInt);
+    TORCH_CHECK_DTYPE(index, kInt);
+
+    int rows = tensor.size(0);
+    int cols = index.size(0);
+    uint32_t* temp = (uint32_t*) calloc(cols / 8, sizeof(int));
+    uint32_t* a = (uint32_t*) tensor.data_ptr();
+    uint32_t* idx = (uint32_t*) index.data_ptr();
+
+    for (int r = 0; r < rows; ++r)
+    {
+        memcpy(temp, a, sizeof(uint32_t) * cols / 8);
+        for (int c = 0; c < cols;)
+        {
+            uint32_t rv = 0;
+            for (int b = 0; b < 8; ++b, ++c)
+            {
+                uint32_t i = idx[c];
+                uint32_t v = (temp[i / 8] >> ((i & 7) * 4) & 0x0f);
+                rv |= v << (b * 4);
+            }
+            *a++ = rv;
+        }
+    }
+    free(temp);
+}
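For reference, the two routines compute the following (a NumPy/PyTorch sketch for illustration, not the extension code): tensor_remap gathers the columns of a contiguous int32 matrix according to index, and tensor_remap_4bit does the same for a matrix whose int32 elements each pack eight 4-bit values, with index addressing the unpacked nibble columns. Unlike this sketch, the C++ versions overwrite the input tensor in place, reusing a single row-sized scratch buffer.

import numpy as np
import torch

def tensor_remap_ref(t: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Gather the columns of an int32 matrix: out[r, c] = t[r, index[c]].
    return t[:, index.long()].contiguous()

def tensor_remap_4bit_ref(t: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # t packs eight 4-bit values per int32 element; index has t.size(1) * 8 entries.
    # Assumes contiguous CPU tensors, as in the loading path this commit targets.
    a = t.numpy().view(np.uint32)                            # (rows, cols // 8)
    idx = index.numpy().astype(np.int64)                     # (cols,)
    nibbles = (a[:, idx // 8] >> ((idx & 7) * 4)) & 0x0F     # gather unpacked nibbles
    nibbles = nibbles.reshape(a.shape[0], -1, 8).astype(np.uint32)
    shifts = np.arange(8, dtype = np.uint32) * 4
    packed = (nibbles << shifts).sum(axis = -1).astype(np.uint32)   # repack 8 nibbles per word
    return torch.from_numpy(packed.view(np.int32))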

exllamav2/exllamav2_ext/cpp/safetensors.h

Lines changed: 13 additions & 0 deletions
@@ -47,4 +47,17 @@ uintptr_t safetensors_open_fb(const char* filename);
 void safetensors_close_fb(uintptr_t handle);
 void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target);
 
+void tensor_remap
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+);
+
+void tensor_remap_4bit
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+);
+
+
 #endif

exllamav2/exllamav2_ext/ext_bindings.cpp

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("safetensors_pinned_buffer", &safetensors_pinned_buffer, "safetensors_pinned_buffer");
     m.def("safetensors_free_pinned_buffer", &safetensors_free_pinned_buffer, "safetensors_free_pinned_buffer");
     m.def("safetensors_read_fb", &safetensors_read_fb, "safetensors_read_fb");
+    m.def("tensor_remap", &tensor_remap, "tensor_remap");
+    m.def("tensor_remap_4bit", &tensor_remap_4bit, "tensor_remap_4bit");
 
     // qmatrix
 
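Once built, the new bindings are callable from Python like the other safetensors helpers. A rough usage sketch (the import path is assumed from how the extension is referenced elsewhere in the repo; shapes and dtypes follow the C++ checks above):

import torch
from exllamav2.ext import exllamav2_ext as ext_c   # import path assumed

rows, cols = 4, 32
w = torch.randint(0, 1 << 30, (rows, cols), dtype = torch.int)   # contiguous int32, on CPU
perm = torch.randperm(cols).to(torch.int)                        # int32 column permutation

expected = w[:, perm.long()].clone()
ext_c.tensor_remap(w, perm)            # permutes the columns of w in place
assert torch.equal(w, expected)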
exllamav2/ext.py

Lines changed: 2 additions & 2 deletions
@@ -173,9 +173,9 @@ def find_msvc():
 # gcc / cl.exe flags
 
 if windows:
-    extra_cflags = ["/Ox", "/openmp"]
+    extra_cflags = ["/Ox"]
 else:
-    extra_cflags = ["-Ofast", "-fopenmp"]
+    extra_cflags = ["-Ofast"]
 
 if ext_debug:
     extra_cflags += ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]

exllamav2/linear.py

Lines changed: 7 additions & 4 deletions
@@ -8,6 +8,7 @@
 from exllamav2.compat import safe_move_tensor
 from exllamav2.tensor_p import BROADCAST_VC
 from exllamav2.util import unpack_4bit, pack_4bit
+import gc
 
 from typing import TYPE_CHECKING
 
@@ -118,7 +119,7 @@ def load(self,
         cfg = self.model.config
 
         if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv)
-        if w is None: w = self.load_weight()
+        if w is None: w = self.load_weight(cpu = output_map is not None)
 
         # Load quantized linear layer from dictionary
 
@@ -137,7 +138,7 @@ def load(self,
             self.q_tensors = w
 
             if unmap and "q_perm" in w:
-                perm = w["q_perm"]
+                perm = w["q_perm"].cpu()
                 del w["q_perm"]
                 del w["q_invperm"]
                 # w["q_perm"] = torch.arange(0, w["q_perm"].shape[-1], dtype = w["q_perm"].dtype, device = w["q_perm"].device)
@@ -146,8 +147,10 @@ def load(self,
                 perm = None
 
             if output_map is not None:
-                w["q_weight"] = w["q_weight"][:, output_map]
-                w["q_scale"] = pack_4bit(unpack_4bit(w["q_scale"])[:, output_map])
+                ext_c.tensor_remap(w["q_weight"], output_map)
+                ext_c.tensor_remap_4bit(w["q_scale"], output_map)
+                for k in w.keys():
+                    w[k] = safe_move_tensor(w[k], self.device())
 
             self.q_handle = ext.make_q_matrix(w,
                                               self.temp_dq,

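The replaced lines built new tensors on the layer's device: fancy indexing for q_weight and an unpack/pack round trip for q_scale, each of which allocates temporaries. The new calls permute the CPU-resident tensors in place, after which everything is moved to the device exactly once via safe_move_tensor. A small equivalence sketch for the 4-bit case (illustrative only; shapes are made up, the import path is assumed, and the two paths are expected to agree since that is the point of the change):

import torch
from exllamav2.util import unpack_4bit, pack_4bit
from exllamav2.ext import exllamav2_ext as ext_c   # import path assumed

q_scale = torch.randint(0, 1 << 30, (8, 16), dtype = torch.int)   # 16 int32 cols = 128 nibble columns
output_map = torch.randperm(128).to(torch.int)

old = pack_4bit(unpack_4bit(q_scale)[:, output_map.long()])   # previous path: allocates copies
new = q_scale.clone()
ext_c.tensor_remap_4bit(new, output_map)                      # new path: remap in place
print(torch.equal(old, new))                                  # expected: True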
exllamav2/module.py

Lines changed: 9 additions & 7 deletions
@@ -60,7 +60,8 @@ def device(self) -> str:
     def load_multi(self,
                    key: str,
                    keys: list[str],
-                   measure: bool = False) -> int | dict[str: torch.Tensor]:
+                   measure: bool = False,
+                   cpu: bool = False) -> int | dict[str: torch.Tensor]:
 
         tensors = {}
         submap = {}
@@ -85,13 +86,14 @@ def load_multi(self,
             if measure:
                 size += stfile.measure(key + "." + k)
             else:
-                tensors[k] = stfile.get_tensor(key + "." + k, device = self.device())
+                tensors[k] = stfile.get_tensor(key + "." + k, device = self.device() if not cpu else "cpu")
 
         return size if measure else tensors
 
 
     def load_weight(self,
-                    override_key: str | None = None):
+                    override_key: str | None = None,
+                    cpu: bool = False):
 
         if override_key is not None:
             keys = [override_key]
@@ -105,14 +107,14 @@ def load_weight(self,
         # EXL2
 
         if key + ".q_weight" in self.model.config.tensor_file_map:
-            qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"])
+            qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"], cpu = cpu)
             qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
            return qtensors
 
         # GPTQ
 
         if key + ".qweight" in self.model.config.tensor_file_map:
-            qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"])
+            qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"], cpu = cpu)
             if "bias" in qtensors and torch.all(qtensors["bias"].eq(0)):
                 del qtensors["bias"]
             qtensors["scales"] = qtensors["scales"].half()
@@ -122,14 +124,14 @@ def load_weight(self,
 
         if key + ".weight" in self.model.config.tensor_file_map:
             if key + ".bias" in self.model.config.tensor_file_map:
-                tensors = self.load_multi(key, ["weight", "bias"])
+                tensors = self.load_multi(key, ["weight", "bias"], cpu = cpu)
                 tensor = tensors["weight"].half()
                 bias = tensors["bias"].half()
                 if self.model.config.arch.orig_weights_transposed and len(tensor.shape) == 2:
                     tensor = tensor.T
                 return nn.Parameter(tensor, requires_grad = False), nn.Parameter(bias, requires_grad = False)
             else:
-                tensors = self.load_multi(key, ["weight"])
+                tensors = self.load_multi(key, ["weight"], cpu = cpu)
                 tensor = tensors["weight"].half()
                 # if self.model.config.arch.orig_weights_transposed:
                 # tensor = tensor.T
