
Commit 7541421

convert.py: Add new parallel mode
1 parent 7f45c2a commit 7541421

File tree

3 files changed (+194 additions, -80 deletions):
  exllamav3/conversion/convert_model.py
  exllamav3/modules/linear.py
  exllamav3/modules/quant/exl3_lib/quantize.py

exllamav3/conversion/convert_model.py

Lines changed: 137 additions & 40 deletions
@@ -16,6 +16,7 @@
 from safetensors import safe_open
 import os, shutil
 import json
+import threading

 col_default = "\u001b[0m"
 col_red = "\u001b[31;1m"
@@ -40,6 +41,7 @@
 parser.add_argument("-img", "--image_dump", action = "store_true", help = "Save model tensors as images (saved to working directory)")
 parser.add_argument("-cb", "--codebook", type = str, default = "mcg", help = "Codebook: mcg (default), mul1 or 3inst")
 parser.add_argument("-strat", "--strategy", type = str, default = None, help = "Modifiers for quantization strategy - EXPERIMENTAL")
+parser.add_argument("-pm", "--parallel_mode", action = "store_true", help = "When possible, use new parallel mode for small tensors (MoE layers especially)")

 group = parser.add_mutually_exclusive_group()
 group.add_argument("--out_scales", dest = "out_scales_", action = "store_true", help = "Always enable out channel scales (for debug purposes)")
@@ -50,6 +52,10 @@

 num_ref_states = 5

+progress_lock = threading.Lock()
+curr_progress = 0
+max_progress = 0
+
 def check_system():
     if os.environ.get("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") is not None:
         print(
@@ -167,6 +173,7 @@ def override(arg, can_override, default):
     ("device_ratios", True, None),
     ("codebook", True, "mcg"),
     ("strategy", False, ""),
+    ("parallel_mode", True, False),
 ]:
     override(arg_, can_override if not args.override_anyway else True, default)

@@ -268,6 +275,7 @@ def mod_strategy(args, module, strategy, idx):

 @torch.inference_mode()
 def main(args, job_state):
+    global max_progress, curr_progress

     torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 200)

@@ -399,47 +407,136 @@ def main(args, job_state):
     for linear in linears:
         linear.inner.swap_cpu()

-    # Quantize module
-    for linear in linears:
-        quant_args = {
-            "seed": idx,
-            "K": strategy[linear.key],
-            "devices": devices,
-            "device_ratios": device_ratios,
-            "apply_out_scales": args["apply_out_scales"],
-        }
-        if args["codebook"] == "mcg":
-            quant_args.update({
-                "mcg": True
-            })
-        elif args["codebook"] == "mul1":
-            quant_args.update({
-                "mul1": True
-            })
-
-        with Timer() as t:
-            sr = os.path.join(args["work_dir"], f"images/{linear.key}.reg.jpg") \
-                if args["image_dump"] else None
-            proxy_err = linear.convert_exl3(
-                capture_H[linear.qmap],
-                quant_args = quant_args,
-                progress_str = f" -- <step>: {linear.key}",
-                verbose = args["verbose"],
-                save_reg = sr
+    # Decide mode
+    # TODO: Might be useful to compare no. h-tiles per tensor, no. layers and no. SMs across GPUs
+    use_parallel_mode = False
+    if args["parallel_mode"] and len(linears) >= len(devices):
+        use_parallel_mode = True
+
+    # Quantize module, layer parallel
+    if use_parallel_mode:
+        assert not args["image_dump"], "Parallel mode is incompatible with --image_dump"
+
+        # Split workload
+        all_dev_linears = [[] for _ in devices]
+
+        tot_numel = sum(linear.weights_numel() for linear in linears)
+        if device_ratios is None:
+            dev_numel = [tot_numel // len(devices) for _ in devices]
+        else:
+            tot_split = sum(device_ratios)
+            dev_numel = [tot_numel * r // tot_split for _, r in zip(devices, device_ratios)]
+
+        for linear in linears:
+            l_numel = linear.weights_numel()
+            fit = [d_numel - l_numel for d_numel in dev_numel]
+            bestfit = max(range(len(fit)), key = lambda x: fit[x])
+            dev_numel[bestfit] -= l_numel
+            all_dev_linears[bestfit].append(linear)
+
+        with progress_lock:
+            curr_progress = 0
+            max_progress = len(linears)
+
+        # Worker thread
+        def work_thread(device_idx, dev_linears):
+            global curr_progress
+
+            for linear in dev_linears:
+                quant_args_local = {
+                    "seed": idx,
+                    "K": strategy[linear.key],
+                    "devices": [device_idx],
+                    "apply_out_scales": args["apply_out_scales"],
+                }
+                if args["codebook"] == "mcg": quant_args_local.update({ "mcg": True })
+                elif args["codebook"] == "mul1": quant_args_local.update({ "mul1": True })
+
+                proxy_err = linear.convert_exl3(
+                    capture_H[linear.qmap],
+                    quant_args = quant_args_local,
+                    verbose = args["verbose"],
+                    save_reg = False,
+                    override_swap_device = device_idx
+                )
+                assert isinstance(linear.inner, LinearEXL3)
+                linear.inner.swap_cpu()
+
+                flags = "o" if quant_args_local["apply_out_scales"] else "."
+                proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) "
+                print(
+                    f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}"
+                    f" bpw: {quant_args_local['K']:5.2f}"
+                    f" proxy_err: {proxy_err_str}"
+                    f" {flags}"
+                    f" g_sc: {quant_args_local['g_scale']:.6f}"
+                )
+                with progress_lock:
+                    curr_progress += 1
+
+        # Launch
+        threads = []
+        for i, device_idx in enumerate(devices):
+            if len(all_dev_linears[i]):
+                t = threading.Thread(target = work_thread, args = (device_idx, all_dev_linears[i]))
+                t.daemon = True
+                threads.append(t)
+        for t in threads:
+            t.start()
+
+        try:
+            with ProgressBar(" -- Quantizing (parallel)", max_progress, transient = True) as progress:
+                while any(t.is_alive() for t in threads):
+                    progress.update(curr_progress)
+                    time.sleep(0.1)
+        except KeyboardInterrupt as e:
+            # TODO: This is too hacky
+            from signal import pthread_kill, SIGTSTP, SIGKILL
+            for t in threads:
+                pthread_kill(t.ident, SIGTSTP)
+                pthread_kill(t.ident, SIGKILL)
+            print("Aborted.")
+            sys.exit()
+
+        for t in threads:
+            t.join(timeout = 0.1)
+
+    # Quantize module, single GPU or tensor split
+    else:
+        for linear in linears:
+            quant_args = {
+                "seed": idx,
+                "K": strategy[linear.key],
+                "devices": devices,
+                "device_ratios": device_ratios,
+                "apply_out_scales": args["apply_out_scales"],
+            }
+            if args["codebook"] == "mcg": quant_args.update({ "mcg": True })
+            elif args["codebook"] == "mul1": quant_args.update({ "mul1": True })
+
+            with Timer() as t:
+                sr = os.path.join(args["work_dir"], f"images/{linear.key}.reg.jpg") \
+                    if args["image_dump"] else None
+                proxy_err = linear.convert_exl3(
+                    capture_H[linear.qmap],
+                    quant_args = quant_args,
+                    progress_str = f" -- <step>: {linear.key}",
+                    verbose = args["verbose"],
+                    save_reg = sr,
+                )
+            assert isinstance(linear.inner, LinearEXL3)
+            linear.inner.swap_cpu()
+            flags = "o" if quant_args["apply_out_scales"] else "."
+            proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) "
+            print(
+                f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}"
+                f" bpw: {quant_args['K']:5.2f}"
+                f" proxy_err: {proxy_err_str}"
+                f" {flags}"
+                f" g_sc: {quant_args['g_scale']:.6f}"
+                f" [{t.interval:4.2f} s]"
             )
-        assert isinstance(linear.inner, LinearEXL3)
-        linear.inner.swap_cpu()
-        flags = "o" if quant_args["apply_out_scales"] else "."
-        proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) "
-        print(
-            f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}"
-            f" bpw: {quant_args['K']:5.2f}"
-            f" proxy_err: {proxy_err_str}"
-            f" {flags}"
-            f" g_sc: {quant_args['g_scale']:.6f}"
-            f" [{t.interval:4.2f} s]"
-        )
-        sys.stdout.flush()
+            sys.stdout.flush()

     # Collect converted module tensors
     for m in module:
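
Note on the new mode (commentary, not part of the diff): the parallel path is opted into with the new -pm / --parallel_mode flag and only activates when a module has at least as many linear layers as there are devices. The workload split is a greedy best-fit assignment: each device starts with a budget proportional to its share of the total parameter count (equal shares, or weighted by --device_ratios), and each tensor goes to whichever device has the most remaining budget. A minimal standalone sketch of the same idea, with hypothetical names (assign_to_devices, sizes):

    # Illustrative sketch only, mirroring the split logic in the hunk above.
    # "sizes" stands in for linear.weights_numel() per tensor; names are hypothetical.
    def assign_to_devices(sizes, n_devices, ratios = None):
        total = sum(sizes)
        if ratios is None:
            budget = [total / n_devices] * n_devices
        else:
            budget = [total * r / sum(ratios) for r in ratios]
        buckets = [[] for _ in range(n_devices)]
        for i, s in enumerate(sizes):
            # Pick the device with the most remaining budget (greedy best fit)
            d = max(range(n_devices), key = lambda j: budget[j] - s)
            budget[d] -= s
            buckets[d].append(i)
        return buckets

    # Example: five tensors over two equal GPUs
    # assign_to_devices([10, 8, 6, 4, 2], 2) -> [[0, 3, 4], [1, 2]]  (16 vs. 14 weights)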

exllamav3/modules/linear.py

Lines changed: 5 additions & 1 deletion
@@ -240,13 +240,17 @@ def convert_exl3(
         progress_str: str | None = None,
         return_weight_q: bool = False,
         verbose: bool = False,
-        save_reg: str = None
+        save_reg: str = None,
+        override_swap_device: torch.device | None = None
     ):
         assert isinstance(self.inner, LinearFP16), \
             "Inner layer is already quant type"

         # Destroy original layer here to save VRAM, we only need weights
         swap_to_device = self.inner.swap_device # in case weights are swapped to CPU
+        if swap_to_device is not None and override_swap_device is not None:
+            swap_to_device = override_swap_device
+
         orig_weight = self.inner.get_weight_tensor().float()
         orig_bias = self.inner.get_bias_tensor()
         self.inner = None
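
Side note (illustration, not part of the diff): override_swap_device only takes effect when the weights were previously swapped to CPU; in that case a parallel worker redirects the swap target to its own GPU instead of the device recorded earlier. A rough sketch of that resolution with a hypothetical helper name:

    # Hypothetical helper; the real logic is the two added lines above.
    def resolve_swap_device(recorded_swap_device, override_swap_device):
        # Weights swapped to CPU remember where they should be moved back to.
        # In parallel mode a worker thread overrides that with its own device,
        # so each layer is quantized on the GPU that owns it.
        if recorded_swap_device is not None and override_swap_device is not None:
            return override_swap_device
        return recorded_swap_device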

exllamav3/modules/quant/exl3_lib/quantize.py

Lines changed: 52 additions & 39 deletions
@@ -8,6 +8,7 @@
 from ....util import cuda_sync_active
 from ....util.tensor import save_tensor_image
 from functools import lru_cache
+import threading

 # Constant
 had_k, had_n = 128, 128
@@ -442,56 +443,60 @@ def ldlq(
     return weight_q, encoded


+finalize_capture_H_mutex = threading.Lock()
+
 def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
-    # Unswap H
-    if "H_swap_device" in H_data:
-        H_data["H"] = H_data["H"].to(H_data["H_swap_device"])
-        del H_data["H_swap_device"]
+    with finalize_capture_H_mutex:

-    H = H_data["H"]
-    if H_data["finalized"]:
-        return H, H_data["L"], H_data["su"], H_data["diag"]
+        # Unswap H
+        if "H_swap_device" in H_data:
+            H_data["H"] = H_data["H"].to(H_data["H_swap_device"])
+            del H_data["H_swap_device"]

-    # Mean of samples summed up during forward pass
-    H /= H_data["count"]
+        H = H_data["H"]
+        if H_data["finalized"]:
+            return H, H_data["L"], H_data["su"], H_data["diag"]

-    # Regularize diagonal
-    diag_mean = torch.diag(H).mean()
-    idx = torch.arange(H.shape[0])
-    H[idx, idx] += quant_args.get("sigma_reg", 0.025) * diag_mean
+        # Mean of samples summed up during forward pass
+        H /= H_data["count"]

-    # Some tests
-    diag = H[idx, idx].clone()
+        # Regularize diagonal
+        diag_mean = torch.diag(H).mean()
+        idx = torch.arange(H.shape[0])
+        H[idx, idx] += quant_args.get("sigma_reg", 0.025) * diag_mean

-    if verbose:
-        print(f" - H min/max: {H.min().item():.6f} {H.max().item():.6f}")
-        print(f" - H mean/std: {H.mean().item():.6f} {H.std().item():.6f}")
-        print(f" - H diag min/max: {diag.min():.6f} {diag.max():.6f} ")
+        # Some tests
+        diag = H[idx, idx].clone()
+
+        if verbose:
+            print(f" - H min/max: {H.min().item():.6f} {H.max().item():.6f}")
+            print(f" - H mean/std: {H.mean().item():.6f} {H.std().item():.6f}")
+            print(f" - H diag min/max: {diag.min():.6f} {diag.max():.6f} ")

-    # Random sign flips for input channel, fixed for the first linear layer to quantize with this H
-    k = H.shape[0]
-    su = (torch.randn(k, device = H.device).sign() + 1e-5).sign().to(torch.float).unsqueeze(1)
-    H_data["su"] = su
+        # Random sign flips for input channel, fixed for the first linear layer to quantize with this H
+        k = H.shape[0]
+        su = (torch.randn(k, device = H.device).sign() + 1e-5).sign().to(torch.float).unsqueeze(1)
+        H_data["su"] = su

-    # Input had
-    H *= su.T
-    blockwise_preapply_had_r_(H, had_k)
-    H *= su
-    blockwise_preapply_had_l_(H, had_k)
+        # Input had
+        H *= su.T
+        blockwise_preapply_had_r_(H, had_k)
+        H *= su
+        blockwise_preapply_had_l_(H, had_k)

-    # Get block LDL decomposition of H, zero diagonal
-    L, H = block_ldl(H, 16, verbose)
-    dr = torch.arange(k)
-    L[dr, dr] = 0
-    H_data["L"] = L
+        # Get block LDL decomposition of H, zero diagonal
+        L, H = block_ldl(H, 16, verbose)
+        dr = torch.arange(k)
+        L[dr, dr] = 0
+        H_data["L"] = L

-    # H is no longer needed except to compute proxy error so move to CPU
-    H = H.cpu()
-    H_data["H"] = H.cpu()
+        # H is no longer needed except to compute proxy error so move to CPU
+        H = H.cpu()
+        H_data["H"] = H.cpu()

-    H_data["finalized"] = True
-    H_data["diag"] = diag
-    return H, L, su, diag
+        H_data["finalized"] = True
+        H_data["diag"] = diag
+        return H, L, su, diag


 def pack_trellis(encoded: torch.Tensor, quant_args: dict) -> torch.Tensor:
@@ -777,11 +782,19 @@ def quantize_exl3(
     if "seed" in quant_args:
        torch.manual_seed(quant_args["seed"])

+    devices = quant_args["devices"]
+    if weight.device != torch.device(devices[0]):
+        weight = weight.to(devices[0])
+
     device = weight.device if swap_to_device is None else swap_to_device
     k, n = weight.shape

     # Get H, LDL decomp. and input/output sign flips
     H, L, su, H_diag = finalize_capture_H(H_data, quant_args, verbose)
+    if H.is_cuda: H = H.to(device)
+    if L.is_cuda: L = L.to(device)
+    if su.is_cuda: su = su.to(device)
+    if H_diag.is_cuda: H_diag = H_diag.to(device)
     sv = (torch.randn(n, device = device).sign() + 1e-5).sign().to(torch.float).unsqueeze(0)

     # Move stored L to CPU (if not already), move working L to device
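
Why the locking change above matters (commentary, not part of the diff): in parallel mode several worker threads can share one captured H and call finalize_capture_H concurrently, and the finalization mutates H in place. Running the whole body under a mutex, with the "finalized" check re-done inside the lock, guarantees the expensive work happens exactly once and later callers just get the cached result; the added lines in quantize_exl3 then copy the shared tensors to each worker's own device before use. A minimal sketch of that lazy, lock-guarded finalization pattern, with hypothetical names (lazy_finalize, state):

    import threading

    _finalize_lock = threading.Lock()

    def lazy_finalize(state: dict):
        # Serialize the whole section: the first caller performs the (in-place,
        # destructive) finalization, later callers return the cached result.
        with _finalize_lock:
            if state.get("finalized"):
                return state["result"]
            state["result"] = sum(state["samples"]) / len(state["samples"])  # stand-in for the real work
            state["finalized"] = True
            return state["result"]

Without the re-check inside the lock, two threads arriving at the same time could both run the finalization, and the second run would operate on data the first had already transformed.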
