
Commit ccaeeba

awaelchli and carmocca authored
Speed up quantization in generate.py (OpenGVLab#35)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent c409960 commit ccaeeba

File tree

5 files changed: +76 -31 lines

generate.py

Lines changed: 9 additions & 17 deletions
@@ -1,4 +1,3 @@
-import os
 import sys
 import time
 from pathlib import Path
@@ -7,8 +6,7 @@
 import lightning as L
 import torch
 
-from lit_llama.model import LLaMA
-from lit_llama.tokenizer import Tokenizer
+from lit_llama import LLaMA, Tokenizer, as_8_bit_quantized
 
 
 @torch.no_grad()
@@ -104,21 +102,13 @@ def main(
 
     fabric = L.Fabric(accelerator=accelerator, devices=1)
 
-    if quantize:
-        from lit_llama.quantization import quantize
-
-        print("Running quantization. This may take a minute ...")
-        # TODO: Initializing the model directly on the device does not work with quantization
+    with as_8_bit_quantized(fabric.device, enabled=quantize):
+        print("Loading model ...", file=sys.stderr)
+        t0 = time.time()
         model = LLaMA.from_name(model_size)
-        # The output layer can be sensitive to quantization, we keep it in default precision
-        model = quantize(model, skip=("lm_head", "output"))
         checkpoint = torch.load(checkpoint_path)
         model.load_state_dict(checkpoint)
+    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
 
@@ -133,6 +123,7 @@ def main(
 
     L.seed_everything(1234)
     t0 = time.time()
+
    for _ in range(num_samples):
        y = generate(
            model,
@@ -144,8 +135,9 @@ def main(
         )[0] # unpack batch dimension
         print(tokenizer.decode(y))
 
-    print(f"Time for inference: {time.time() - t0:.02f} seconds", file=sys.stderr)
-    print(f"Memory used (GB): {torch.cuda.max_memory_reserved() / 1e9:.02f}", file=sys.stderr)
+    t = time.time() - t0
+    print(f"\n\nTime for inference: {t:.02f} sec total, {max_new_tokens / t:.02f} tokens/sec", file=sys.stderr)
+    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)
 
 
 if __name__ == "__main__":
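
In the updated script, main() constructs and loads the model inside as_8_bit_quantized(fabric.device, enabled=quantize) instead of quantizing an already-built float32 model afterwards, and it now reports model loading time and tokens/sec for generation. Below is a minimal sketch of the new loading path as standalone code, assuming a CUDA GPU with bitsandbytes installed; the "7B" size name and the checkpoint path are placeholders, not values from this commit.

import sys
import time

import torch

from lit_llama import LLaMA, as_8_bit_quantized

device = torch.device("cuda")
with as_8_bit_quantized(device, enabled=True):
    print("Loading model ...", file=sys.stderr)
    t0 = time.time()
    # Linear layers are created directly as 8-bit layers on the device.
    model = LLaMA.from_name("7B")  # assumed size name, for illustration only
    checkpoint = torch.load("checkpoints/lit-llama/7B/state_dict.pth")  # hypothetical path
    model.load_state_dict(checkpoint)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
model.eval()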

lit_llama/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 from lit_llama.model import LLaMAConfig, LLaMA, RMSNorm, build_rope_cache, apply_rope
+from lit_llama.quantization import as_8_bit_quantized
 from lit_llama.tokenizer import Tokenizer

lit_llama/model.py

Lines changed: 0 additions & 3 deletions
@@ -184,9 +184,6 @@ def __init__(self, config: LLaMAConfig) -> None:
             )
         )
 
-        # init all weights
-        self.apply(self._init_weights)
-
     def _init_weights(self, module: nn.Module) -> None:
         if isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
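
The LLaMA constructor no longer applies _init_weights to every parameter, so constructing the model does not spend time initializing weights that a checkpoint immediately overwrites; random initialization is now opt-in (the test change below calls it explicitly). A minimal sketch of explicit initialization, using a tiny made-up configuration; the LLaMAConfig field names are assumed from lit_llama/model.py and are not part of this commit.

from lit_llama import LLaMA, LLaMAConfig

# Small, made-up configuration so the sketch is cheap to run on CPU.
config = LLaMAConfig(block_size=128, vocab_size=100, n_layer=2, n_head=4, n_embd=32)
model = LLaMA(config)             # weights are no longer initialized in __init__
model.apply(model._init_weights)  # opt in explicitly when random weights are needed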

lit_llama/quantization.py

Lines changed: 65 additions & 11 deletions
@@ -1,19 +1,73 @@
 import os
-from typing import Tuple
+from contextlib import contextmanager
+import warnings
 
-import torch.nn as nn
+import torch
 
+# configuration for bitsandbytes before import
 os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+warnings.filterwarnings(
+    "ignore",
+    message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
+)
+warnings.filterwarnings(
+    "ignore",
+    message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable."
+)
 import bitsandbytes as bnb # noqa: E402
 
 
-def quantize(model: nn.Module, threshold: float = 6.0, skip: Tuple[str, ...] = ()) -> nn.Module:
-    for name, module in model.named_children():
-        if isinstance(module, nn.Linear) and name not in skip:
-            model._modules[name] = bnb.nn.Linear8bitLt(
-                module.in_features, module.out_features, bias=module.bias, has_fp16_weights=False, threshold=threshold
-            )
+class Linear8bitLt(bnb.nn.Linear8bitLt):
+    """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
+    re-quantizaton when loading the state dict.
+
+
+    This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
+        # We quantize the initial weight here so we don't end up filling the device
+        # memory with float32 weights which could lead to OOM.
+        self._quantize_weight(self.weight.data)
 
-        if module.children():
-            quantize(module, threshold=threshold, skip=skip)
-    return model
+    def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
+        # There is only one key that ends with `*.weight`, the other one is the bias
+        weight_key = next(name for name in local_state_dict.keys() if name.endswith("weight"))
+
+        # Load the weight from the state dict and re-quantize it
+        weight = local_state_dict.pop(weight_key)
+        self._quantize_weight(weight)
+
+        # If there is a bias, let nn.Module load it
+        if local_state_dict:
+            super()._load_from_state_dict(local_state_dict, *args, **kwargs)
+
+    def _quantize_weight(self, weight: torch.Tensor) -> None:
+        # This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
+        B = weight.contiguous().half().cuda()
+        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+        del CBt
+        del SCBt
+        self.weight.data = CB
+        setattr(self.weight, "CB", CB)
+        setattr(self.weight, "SCB", SCB)
+
+
+@contextmanager
+def as_8_bit_quantized(device: torch.device, enabled: bool = True):
+    """A context manager under which you can instantiate the model with 8-bit quantized tensors
+    being created directly on the given device.
+    """
+
+    with torch.device(device):
+        if not enabled:
+            yield
+            return
+
+        if device.type != "cuda":
+            raise ValueError("Quantization is only supported on the GPU.")
+
+        torch_linear_cls = torch.nn.Linear
+        torch.nn.Linear = Linear8bitLt
+        yield
+        torch.nn.Linear = torch_linear_cls
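
as_8_bit_quantized works by temporarily monkeypatching torch.nn.Linear: while the context is active (and quantization is enabled on a CUDA device), every linear layer constructed, for example inside LLaMA.from_name, is actually a Linear8bitLt, which quantizes its weight immediately and re-quantizes checkpoint weights during _load_from_state_dict. A small sketch of that behaviour, assuming a CUDA GPU with bitsandbytes installed; the layer sizes are made up.

import torch

from lit_llama.quantization import Linear8bitLt, as_8_bit_quantized

device = torch.device("cuda")
with as_8_bit_quantized(device, enabled=True):
    # Inside the context, torch.nn.Linear has been swapped for Linear8bitLt.
    layer = torch.nn.Linear(64, 64, bias=False)
print(isinstance(layer, Linear8bitLt))                    # True: built as an 8-bit layer
print(isinstance(torch.nn.Linear(64, 64), Linear8bitLt))  # False: the patch is undone on exit

Note that the yield is not wrapped in try/finally, so the patch is only reverted on a normal exit from the context.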

tests/test_model.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def test_to_orig_llama(lit_llama, orig_llama) -> None:
     )
 
     llama_model = lit_llama.LLaMA(llama_config)
+    llama_model.apply(llama_model._init_weights)
     orig_llama_model = orig_llama.Transformer(orig_llama_config)
 
     copy_weights(llama_model, orig_llama_model)
