import os
-from typing import Tuple
+from contextlib import contextmanager
+import warnings

-import torch.nn as nn
+import torch

+# configuration for bitsandbytes before import
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+warnings.filterwarnings(
+    "ignore",
+    message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization"
+)
+warnings.filterwarnings(
+    "ignore",
+    message="The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable."
+)
import bitsandbytes as bnb  # noqa: E402


-def quantize(model: nn.Module, threshold: float = 6.0, skip: Tuple[str, ...] = ()) -> nn.Module:
-    for name, module in model.named_children():
-        if isinstance(module, nn.Linear) and name not in skip:
-            model._modules[name] = bnb.nn.Linear8bitLt(
-                module.in_features, module.out_features, bias=module.bias, has_fp16_weights=False, threshold=threshold
-            )
+class Linear8bitLt(bnb.nn.Linear8bitLt):
+    """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and
+    re-quantization when loading the state dict.
+
+    This should only be used for inference. For training, use `bnb.nn.Linear8bitLt` directly.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, has_fp16_weights=False, threshold=6.0)
+        # We quantize the initial weight here so we don't end up filling the device
+        # memory with float32 weights, which could lead to OOM.
+        self._quantize_weight(self.weight.data)

-        if module.children():
-            quantize(module, threshold=threshold, skip=skip)
-    return model
+    def _load_from_state_dict(self, local_state_dict, *args, **kwargs):
+        # There is only one key that ends with `*.weight`; the other one is the bias
+        weight_key = next(name for name in local_state_dict.keys() if name.endswith("weight"))
+
+        # Load the weight from the state dict and re-quantize it
+        weight = local_state_dict.pop(weight_key)
+        self._quantize_weight(weight)
+
+        # If there is a bias, let nn.Module load it
+        if local_state_dict:
+            super()._load_from_state_dict(local_state_dict, *args, **kwargs)
+
+    def _quantize_weight(self, weight: torch.Tensor) -> None:
+        # This code is taken and adapted from `bnb.nn.Int8Params.cuda()`
+        B = weight.contiguous().half().cuda()
+        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+        del CBt
+        del SCBt
+        self.weight.data = CB
+        setattr(self.weight, "CB", CB)
+        setattr(self.weight, "SCB", SCB)
+
+
+@contextmanager
+def as_8_bit_quantized(device: torch.device, enabled: bool = True):
+    """A context manager under which the model can be instantiated with its 8-bit
+    quantized tensors created directly on the given device.
+    """
+
+    with torch.device(device):
+        if not enabled:
+            yield
+            return
+
+        if device.type != "cuda":
+            raise ValueError("Quantization is only supported on the GPU.")
+
+        torch_linear_cls = torch.nn.Linear
+        torch.nn.Linear = Linear8bitLt
+        yield
+        torch.nn.Linear = torch_linear_cls
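
Below is a minimal usage sketch (not part of the diff) of how the two pieces are meant to combine for inference. The `TinyMLP` model and the "checkpoint.pth" path are illustrative assumptions only; the sketch also assumes a CUDA device, a bitsandbytes build with GPU support, a PyTorch version where `torch.device(...)` works as a context manager, and a float32 state dict saved with `torch.save`.

import torch
import torch.nn as nn


class TinyMLP(nn.Module):
    """Hypothetical example model; any module built from `nn.Linear` behaves the same way."""

    def __init__(self) -> None:
        super().__init__()
        # These calls resolve to the patched `Linear8bitLt` when run under the context manager
        self.fc1 = nn.Linear(4096, 11008, bias=False)
        self.fc2 = nn.Linear(11008, 4096, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


device = torch.device("cuda")
with as_8_bit_quantized(device, enabled=True):
    # `torch.nn.Linear` is swapped for `Linear8bitLt` inside this block, so every
    # linear layer is created on the GPU and quantized to int8 at construction time.
    model = TinyMLP()

# Loading a float32 checkpoint goes through the overridden `_load_from_state_dict`,
# which re-quantizes each weight tensor on the GPU ("checkpoint.pth" is a placeholder).
model.load_state_dict(torch.load("checkpoint.pth"))
model.eval()

Because the swap is a plain assignment to `torch.nn.Linear`, only modules constructed inside the `with` block are affected, and the original class is restored when the block exits.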