|
16 | 16 | from safetensors import safe_open |
17 | 17 | import os, shutil |
18 | 18 | import json |
| 19 | +import threading |
19 | 20 |
|
# ANSI escape sequences for colored terminal output
col_default = "\u001b[0m"  # reset all attributes to terminal default
col_red = "\u001b[31;1m"  # bright/bold red, used for warnings
|
# Debug/experimental options for the quantization run
parser.add_argument("-img", "--image_dump", action = "store_true", help = "Save model tensors as images (saved to working directory)")
parser.add_argument("-cb", "--codebook", type = str, default = "mcg", help = "Codebook: mcg (default), mul1 or 3inst")
parser.add_argument("-strat", "--strategy", type = str, default = None, help = "Modifiers for quantization strategy - EXPERIMENTAL")
# NOTE(review): parallel mode assigns whole tensors to devices (one worker thread per
# GPU) rather than splitting each tensor across devices; see main() below
parser.add_argument("-pm", "--parallel_mode", action = "store_true", help = "When possible, use new parallel mode for small tensors (MoE layers especially)")
43 | 45 |
|
# --out_scales and its counterpart (not visible here) are mutually exclusive
group = parser.add_mutually_exclusive_group()
group.add_argument("--out_scales", dest = "out_scales_", action = "store_true", help = "Always enable out channel scales (for debug purposes)")
|
50 | 52 |
|
51 | 53 | num_ref_states = 5 |
52 | 54 |
|
# Shared progress state for parallel-mode quantization: worker threads increment
# curr_progress under progress_lock; the main thread polls it for the progress bar.
progress_lock = threading.Lock()  # guards curr_progress and max_progress
curr_progress = 0  # tensors quantized so far in the current module
max_progress = 0  # total tensors to quantize in the current module
| 58 | + |
53 | 59 | def check_system(): |
54 | 60 | if os.environ.get("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") is not None: |
55 | 61 | print( |
@@ -167,6 +173,7 @@ def override(arg, can_override, default): |
167 | 173 | ("device_ratios", True, None), |
168 | 174 | ("codebook", True, "mcg"), |
169 | 175 | ("strategy", False, ""), |
| 176 | + ("parallel_mode", True, False), |
170 | 177 | ]: |
171 | 178 | override(arg_, can_override if not args.override_anyway else True, default) |
172 | 179 |
|
@@ -268,6 +275,7 @@ def mod_strategy(args, module, strategy, idx): |
268 | 275 |
|
269 | 276 | @torch.inference_mode() |
270 | 277 | def main(args, job_state): |
| 278 | + global max_progress, curr_progress |
271 | 279 |
|
272 | 280 | torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 200) |
273 | 281 |
|
@@ -399,47 +407,136 @@ def main(args, job_state): |
399 | 407 | for linear in linears: |
400 | 408 | linear.inner.swap_cpu() |
401 | 409 |
|
402 | | - # Quantize module |
403 | | - for linear in linears: |
404 | | - quant_args = { |
405 | | - "seed": idx, |
406 | | - "K": strategy[linear.key], |
407 | | - "devices": devices, |
408 | | - "device_ratios": device_ratios, |
409 | | - "apply_out_scales": args["apply_out_scales"], |
410 | | - } |
411 | | - if args["codebook"] == "mcg": |
412 | | - quant_args.update({ |
413 | | - "mcg": True |
414 | | - }) |
415 | | - elif args["codebook"] == "mul1": |
416 | | - quant_args.update({ |
417 | | - "mul1": True |
418 | | - }) |
419 | | - |
420 | | - with Timer() as t: |
421 | | - sr = os.path.join(args["work_dir"], f"images/{linear.key}.reg.jpg") \ |
422 | | - if args["image_dump"] else None |
423 | | - proxy_err = linear.convert_exl3( |
424 | | - capture_H[linear.qmap], |
425 | | - quant_args = quant_args, |
426 | | - progress_str = f" -- <step>: {linear.key}", |
427 | | - verbose = args["verbose"], |
428 | | - save_reg = sr |
| 410 | + # Decide mode |
| 411 | + # TODO: Might be useful to compare no. h-tiles per tensor, no. layers and no. SMs across GPUs |
| 412 | + use_parallel_mode = False |
| 413 | + if args["parallel_mode"] and len(linears) >= len(devices): |
| 414 | + use_parallel_mode = True |
| 415 | + |
| 416 | + # Quantize module, layer parallel |
| 417 | + if use_parallel_mode: |
| 418 | + assert not args["image_dump"], "Parallel mode is incompatible with --image_dump" |
| 419 | + |
| 420 | + # Split workload |
| 421 | + all_dev_linears = [[] for _ in devices] |
| 422 | + |
| 423 | + tot_numel = sum(linear.weights_numel() for linear in linears) |
| 424 | + if device_ratios is None: |
| 425 | + dev_numel = [tot_numel // len(devices) for _ in devices] |
| 426 | + else: |
| 427 | + tot_split = sum(device_ratios) |
| 428 | + dev_numel = [tot_numel * r // tot_split for _, r in zip(devices, device_ratios)] |
| 429 | + |
| 430 | + for linear in linears: |
| 431 | + l_numel = linear.weights_numel() |
| 432 | + fit = [d_numel - l_numel for d_numel in dev_numel] |
| 433 | + bestfit = max(range(len(fit)), key = lambda x: fit[x]) |
| 434 | + dev_numel[bestfit] -= l_numel |
| 435 | + all_dev_linears[bestfit].append(linear) |
| 436 | + |
| 437 | + with progress_lock: |
| 438 | + curr_progress = 0 |
| 439 | + max_progress = len(linears) |
| 440 | + |
| 441 | + # Worker thread |
| 442 | + def work_thread(device_idx, dev_linears): |
| 443 | + global curr_progress |
| 444 | + |
| 445 | + for linear in dev_linears: |
| 446 | + quant_args_local = { |
| 447 | + "seed": idx, |
| 448 | + "K": strategy[linear.key], |
| 449 | + "devices": [device_idx], |
| 450 | + "apply_out_scales": args["apply_out_scales"], |
| 451 | + } |
| 452 | + if args["codebook"] == "mcg": quant_args_local.update({ "mcg": True }) |
| 453 | + elif args["codebook"] == "mul1": quant_args_local.update({ "mul1": True }) |
| 454 | + |
| 455 | + proxy_err = linear.convert_exl3( |
| 456 | + capture_H[linear.qmap], |
| 457 | + quant_args = quant_args_local, |
| 458 | + verbose = args["verbose"], |
| 459 | + save_reg = False, |
| 460 | + override_swap_device = device_idx |
| 461 | + ) |
| 462 | + assert isinstance(linear.inner, LinearEXL3) |
| 463 | + linear.inner.swap_cpu() |
| 464 | + |
| 465 | + flags = "o" if quant_args_local["apply_out_scales"] else "." |
| 466 | + proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) " |
| 467 | + print( |
| 468 | + f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}" |
| 469 | + f" bpw: {quant_args_local['K']:5.2f}" |
| 470 | + f" proxy_err: {proxy_err_str}" |
| 471 | + f" {flags}" |
| 472 | + f" g_sc: {quant_args_local['g_scale']:.6f}" |
| 473 | + ) |
| 474 | + with progress_lock: |
| 475 | + curr_progress += 1 |
| 476 | + |
| 477 | + # Launch |
| 478 | + threads = [] |
| 479 | + for i, device_idx in enumerate(devices): |
| 480 | + if len(all_dev_linears[i]): |
| 481 | + t = threading.Thread(target = work_thread, args = (device_idx, all_dev_linears[i])) |
| 482 | + t.daemon = True |
| 483 | + threads.append(t) |
| 484 | + for t in threads: |
| 485 | + t.start() |
| 486 | + |
| 487 | + try: |
| 488 | + with ProgressBar(" -- Quantizing (parallel)", max_progress, transient = True) as progress: |
| 489 | + while any(t.is_alive() for t in threads): |
| 490 | + progress.update(curr_progress) |
| 491 | + time.sleep(0.1) |
| 492 | + except KeyboardInterrupt as e: |
| 493 | + # TODO: This is too hacky |
| 494 | + from signal import pthread_kill, SIGTSTP, SIGKILL |
| 495 | + for t in threads: |
| 496 | + pthread_kill(t.ident, SIGTSTP) |
| 497 | + pthread_kill(t.ident, SIGKILL) |
| 498 | + print("Aborted.") |
| 499 | + sys.exit() |
| 500 | + |
| 501 | + for t in threads: |
| 502 | + t.join(timeout = 0.1) |
| 503 | + |
| 504 | + # Quantize module, single GPU or tensor split |
| 505 | + else: |
| 506 | + for linear in linears: |
| 507 | + quant_args = { |
| 508 | + "seed": idx, |
| 509 | + "K": strategy[linear.key], |
| 510 | + "devices": devices, |
| 511 | + "device_ratios": device_ratios, |
| 512 | + "apply_out_scales": args["apply_out_scales"], |
| 513 | + } |
| 514 | + if args["codebook"] == "mcg": quant_args.update({ "mcg": True }) |
| 515 | + elif args["codebook"] == "mul1": quant_args.update({ "mul1": True }) |
| 516 | + |
| 517 | + with Timer() as t: |
| 518 | + sr = os.path.join(args["work_dir"], f"images/{linear.key}.reg.jpg") \ |
| 519 | + if args["image_dump"] else None |
| 520 | + proxy_err = linear.convert_exl3( |
| 521 | + capture_H[linear.qmap], |
| 522 | + quant_args = quant_args, |
| 523 | + progress_str = f" -- <step>: {linear.key}", |
| 524 | + verbose = args["verbose"], |
| 525 | + save_reg = sr, |
| 526 | + ) |
| 527 | + assert isinstance(linear.inner, LinearEXL3) |
| 528 | + linear.inner.swap_cpu() |
| 529 | + flags = "o" if quant_args["apply_out_scales"] else "." |
| 530 | + proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) " |
| 531 | + print( |
| 532 | + f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}" |
| 533 | + f" bpw: {quant_args['K']:5.2f}" |
| 534 | + f" proxy_err: {proxy_err_str}" |
| 535 | + f" {flags}" |
| 536 | + f" g_sc: {quant_args['g_scale']:.6f}" |
| 537 | + f" [{t.interval:4.2f} s]" |
429 | 538 | ) |
430 | | - assert isinstance(linear.inner, LinearEXL3) |
431 | | - linear.inner.swap_cpu() |
432 | | - flags = "o" if quant_args["apply_out_scales"] else "." |
433 | | - proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) " |
434 | | - print( |
435 | | - f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}" |
436 | | - f" bpw: {quant_args['K']:5.2f}" |
437 | | - f" proxy_err: {proxy_err_str}" |
438 | | - f" {flags}" |
439 | | - f" g_sc: {quant_args['g_scale']:.6f}" |
440 | | - f" [{t.interval:4.2f} s]" |
441 | | - ) |
442 | | - sys.stdout.flush() |
| 539 | + sys.stdout.flush() |
443 | 540 |
|
444 | 541 | # Collect converted module tensors |
445 | 542 | for m in module: |
|
0 commit comments