Skip to content

Commit 3c92d75

Browse files
committed
model_init: Add --verbose argument and hide TP split by default
1 parent 448a738 commit 3c92d75

File tree

4 files changed

+18
-5
lines changed

4 files changed

+18
-5
lines changed

exllamav3/model/model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def load_gen(
120120
callback: Callable[[int, int], None] | None = None,
121121
generator: bool = True,
122122
tp_dev_limits: dict | None = None,
123-
tp_backend: str = "native"
123+
tp_backend: str = "native",
124+
verbose: bool = False
124125
):
125126
"""
126127
Load model, generator function. For regular function, call load() with the same arguments
@@ -197,6 +198,9 @@ def load_gen(
197198
198199
:param tp_backend:
199200
str, either "nccl" or "native" (default)
201+
202+
:param verbose:
203+
bool, more info while loading including full TP split
200204
"""
201205

202206
free_mem()
@@ -214,7 +218,7 @@ def load_gen(
214218
"Cannot specify reserve_per_device or use_per_device when loading to single device."
215219
assert not tensor_p, \
216220
"Cannot use tensor_p when loading to single device."
217-
self._load_single(progressbar, device, self.config, self.modules)
221+
self._load_single(progressbar, device, self.config, self.modules, verbose)
218222

219223
# Use/reserve
220224
else:
@@ -264,6 +268,7 @@ def load_gen(
264268
generator,
265269
self.config,
266270
self.modules,
271+
verbose,
267272
)
268273
self.output_device = self.modules[-1].device
269274

@@ -290,6 +295,7 @@ def load_gen(
290295
self.modules,
291296
tp_dev_limits,
292297
tp_backend,
298+
verbose,
293299
)
294300
self.output_device = tp_output_device
295301

exllamav3/model/model_ls.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def _load_single(
2424
device: torch.device,
2525
config: Config,
2626
modules: list,
27+
verbose: bool
2728
):
2829
with ProgressBar(f"Loading" if progressbar else None, len(modules)) as progress:
2930
for idx, module in enumerate(modules):
@@ -57,6 +58,7 @@ def _load_autosplit(
5758
generator: bool,
5859
config: Config,
5960
modules: list,
61+
verbose: bool
6062
):
6163
current_device_i = 0
6264
backup_shape, backup_dtype = self.default_load_shape_dtype(max_chunk_size)
@@ -65,7 +67,7 @@ def _load_autosplit(
6567
touched_devices = []
6668
params = self.default_load_params()
6769

68-
with ProgressBar(f"Loading" if progressbar else None, len(modules)) as progress:
70+
with ProgressBar(f"Loading (LS)" if progressbar else None, len(modules)) as progress:
6971

7072
for idx, module in enumerate(modules):
7173

exllamav3/model/model_tp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def _load_tp(
232232
modules: list,
233233
dev_limits: dict | None,
234234
tp_backend: str,
235+
verbose: bool
235236
):
236237
assert use_per_device is None or reserve_per_device is None
237238
if dev_limits is None: dev_limits = {}
@@ -272,7 +273,8 @@ def _load_tp(
272273
dev_limits = dev_limits,
273274
)
274275
allocator.initial_split(max_mem)
275-
allocator.print_split()
276+
if verbose:
277+
allocator.print_split()
276278
plan = allocator.compile_tp_plan()
277279
self.tp_worker_dispatch_wait_multi(self.active_devices, mp_set_plan, (plan, self.active_devices))
278280

@@ -286,7 +288,7 @@ def _load_tp(
286288
)
287289

288290
# Begin loading modules
289-
with (ProgressBar(f"Loading" if progressbar else None, len(modules)) as progress):
291+
with (ProgressBar(f"Loading (TP)" if progressbar else None, len(modules)) as progress):
290292
for idx, module in enumerate(modules):
291293
last_module = module
292294

exllamav3/model_init.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def add_args(
3434
parser.add_argument("-tp_moe", "--tp_max_parallelism_moe", type = int, help = "(TP) Maximum parallelism for MoE layers", default = None)
3535
parser.add_argument("-tp_linear", "--tp_max_parallelism_linear", type = int, help = "(TP) Maximum parallelism for linear (output) layers", default = None)
3636

37+
parser.add_argument("-v", "--verbose", action = "store_true", help = "Verbose output while loading")
38+
3739
if cache:
3840
parser.add_argument("-cs", "--cache_size", type = int, help = f"Total cache size in tokens, default: {default_cache_size}", default = default_cache_size)
3941
parser.add_argument("-cq", "--cache_quant", type = str, help = "Use quantized cache. Specify either kv_bits or k_bits,v_bits pair")
@@ -158,6 +160,7 @@ def printp(p: bool, s: str):
158160
progressbar = progress,
159161
tp_dev_limits = tp_dev_limits,
160162
tp_backend = args.tp_backend,
163+
verbose = args.verbose,
161164
**kwargs
162165
)
163166

0 commit comments

Comments
 (0)