diff --git a/docs/.nav.yml b/docs/.nav.yml
index 2d0236b248..ee0c0d0004 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -27,6 +27,9 @@ nav:
 - FP8 Example: key-models/mistral-large-3/fp8-example.md
 - Guides:
 - Compression Schemes: guides/compression_schemes.md
+ - Sequential Onloading: guides/sequential_onloading.md
+ - Model Loading: guides/model_loading.md
+ - Distributed Oneshot: guides/distributed_oneshot.md
 - Saving a Model: guides/saving_a_model.md
 - Observers: guides/observers.md
 - Memory Requirements: guides/memory.md
diff --git a/docs/assets/model_graph.jpg b/docs/assets/model_graph.jpg
new file mode 100644
index 0000000000..5fa7b771c6
Binary files /dev/null and b/docs/assets/model_graph.jpg differ
diff --git a/docs/assets/seq_targets.jpg b/docs/assets/seq_targets.jpg
new file mode 100644
index 0000000000..50e4e4773e
Binary files /dev/null and b/docs/assets/seq_targets.jpg differ
diff --git a/docs/assets/sequential_decoder_layers.jpg b/docs/assets/sequential_decoder_layers.jpg
new file mode 100644
index 0000000000..1b559ba90c
Binary files /dev/null and b/docs/assets/sequential_decoder_layers.jpg differ
diff --git a/docs/assets/sequential_onloading.jpg b/docs/assets/sequential_onloading.jpg
new file mode 100644
index 0000000000..5bc55353e9
Binary files /dev/null and b/docs/assets/sequential_onloading.jpg differ
diff --git a/docs/guides/distributed_oneshot.md b/docs/guides/distributed_oneshot.md
new file mode 100644
index 0000000000..66eb06618c
--- /dev/null
+++ b/docs/guides/distributed_oneshot.md
@@ -0,0 +1,71 @@
+## Distributed Oneshot ##
+As an experimental feature, LLM Compressor supports distributed oneshot to greatly speed up model calibration and compression.
+For more information on the implementation, see [[RFC] [Performance Refactor][Distributed] Sequential Onloading with Data-Parallel Calibration and Weight-Parallel Optimization](https://github.com/vllm-project/llm-compressor/issues/2180) and [[GPTQ][ddp] enabling DDP for GPTQ](https://github.com/vllm-project/llm-compressor/pull/2333).
+
+## Usage ##
+To convert a script written for single-process compression into one that performs distributed compression, make the following changes:
+
+### 1. Initialize the Distributed Context ###
+
+To use the `torch.distributed` module, each rank must initialize the process group and assign itself a separate GPU device. This can be done by calling the `init_dist` utility provided by `compressed_tensors`.
+
+```python
+from compressed_tensors.offload import init_dist
+
+init_dist()
+```
+
+### 2. Modify Model Loading ###
+
+To prevent separate processes from loading the model multiple times and creating excess work and memory usage, load the model using the `load_offloaded_model` context. For more information, see [Model Loading](./model_loading.md#distributed-oneshot).
+
+Before:
+```python
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    dtype="auto"
+)
+```
+
+After:
+```python
+from compressed_tensors.offload import load_offloaded_model
+
+with load_offloaded_model():
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        dtype="auto",
+        device_map="auto_offload",
+    )
+```
+
+### 3. Modify Dataset Loading ###
+
+To prevent separate processes from loading the entire dataset and creating excess work and memory usage, partition the dataset into disjoint sets. For a dataset of *N* samples and *R* ranks, each rank only loads *N/R* samples.
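Conceptually, the *N/R* partitioning can be sketched as follows. This is a minimal illustration only: the helper name `rank_partition` and its explicit `rank`/`world_size` parameters are hypothetical, while the actual helper used in this guide is `llmcompressor.datasets.utils.get_rank_partition`, which reads the rank information from the distributed context.

```python
def rank_partition(split: str, num_samples: int, rank: int, world_size: int) -> str:
    """Build a HuggingFace `datasets` split-slice string that gives this rank
    a disjoint chunk of num_samples // world_size calibration samples."""
    per_rank = num_samples // world_size
    start = rank * per_rank
    end = start + per_rank
    return f"{split}[{start}:{end}]"

# With 512 samples across 4 ranks, rank 0 loads "train[0:128]".
print(rank_partition("train", 512, 3, 4))  # -> train[384:512]
```

Because the slices are disjoint, no sample is calibrated twice and no rank holds more than its share of the dataset in memory.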
+
+Before:
+```python
+ds = load_dataset(
+    DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"
+)
+```
+
+After:
+```python
+from llmcompressor.datasets.utils import get_rank_partition
+
+ds = load_dataset(
+    DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
+)
+```
+
+### 4. Call your script with `torchrun` ###
+
+Your script is now ready to run with distributed processes. To start, launch it with `torchrun --nproc_per_node=2 YOUR_EXAMPLE.py` to use two GPU devices. For a complete example script, see [llama_ddp_example.py](/examples/quantization_w4a16/llama3_ddp_example.py). The table below shows results and speedups as of LLM Compressor v0.10.0; future changes will bring these numbers closer to linear speedups.
+
+| model_id | world_size | max_time | max_memory | save_time | flex_extract | eval_time |
+|----------|------------|----------|------------|-----------|--------------|-----------|
+| Meta-Llama-3-8B-Instruct | 1 | 745.03 | 5.82 | 19.57 | 0.7066 | 95.28 |
+| Meta-Llama-3-8B-Instruct | 2 | 372.20 | 5.57 | 49.10 | 0.7089 | 95.24 |
+| Meta-Llama-3-8B-Instruct | 4 | 264.07 | 5.82 | 52.50 | 0.7180 | 96.74 |
+| Qwen3-30B-A3B | 1 | 14207.53 | 6.56 | 748.23 | 0.8704 | 209.93 |
+| Qwen3-30B-A3B | 2 | 7018.25 | 6.36 | 696.65 | 0.8810 | 205.89 |
+| Qwen3-30B-A3B | 4 | 3694.46 | 6.36 | 723.05 | 0.8832 | 217.62 |
\ No newline at end of file
diff --git a/docs/guides/memory.md b/docs/guides/memory.md
index e736de80a5..48a942121c 100644
--- a/docs/guides/memory.md
+++ b/docs/guides/memory.md
@@ -17,7 +17,7 @@ Also, larger models, like DeepSeek R1 use a large amount of CPU memory, and mode
 2. How text decoder layers and vision tower layers are loaded on to GPU differs significantly.
- In the case of text decoder layers, LLM Compressor dynamically loads one layer at a time into the GPU for computation. The rest of the model remains in CPU memory.
+ In the case of text decoder layers, LLM Compressor typically loads one layer at a time into the GPU for computation, while the rest of the model remains offloaded to CPU or disk. For more information, see [Sequential Onloading](./sequential_onloading.md).
 However, vision tower layers are loaded onto GPU all at once. Unlike the text model, vision towers are not split up into individual layers before onloading to the GPU. This can create a GPU memory bottleneck for models whose vision towers are larger than their text layers.
diff --git a/docs/guides/model_loading.md b/docs/guides/model_loading.md
new file mode 100644
index 0000000000..9a3574228e
--- /dev/null
+++ b/docs/guides/model_loading.md
@@ -0,0 +1,71 @@
+# Model Loading #
+LLM Compressor utilizes the [Compressed Tensors](https://github.com/vllm-project/compressed-tensors) library to handle model offloading. In nearly all cases, it is recommended to compress your model using the [sequential pipeline](./sequential_onloading.md), which enables the quantization of large models without requiring significant VRAM.
+
+!!! tip
+    For more information on when to use the *basic* pipeline rather than the *sequential* pipeline, see [Basic Pipeline](./model_loading.md#basic-pipeline). In those cases, it is recommended to load your model onto GPU first, rather than CPU or disk.
+
+Loading your model directly onto CPU is simple using `transformers`:
+
+```python
+# model is on cpu
+model = AutoModelForCausalLM.from_pretrained(model_stub, dtype="auto")
+```
+
+However, some situations require more advanced loading logic. The tables below show the behavior of different model loading configurations.
+
+Distributed=False | device_map="auto" | device_map="cuda" | device_map="cpu" | device_map="auto_offload"
+-- | -- | -- | -- | --
+`load_offloaded_model` context required? | No | No | No | Yes
+Behavior | Try to load the model onto all visible CUDA devices. Fall back to CPU and disk if the model is too large | Try to load the model onto the first CUDA device only. Error if the model is too large | Try to load the model onto CPU. Error if the model is too large | Try to load the model onto CPU. Fall back to disk if the model is too large
+LLM Compressor Examples | This is the recommended load option when using the "basic" pipeline |   |   | This is the recommended load option when using the "sequential" pipeline
+
+Distributed=True | device_map="auto" | device_map="cuda" | device_map="cpu" | device_map="auto_offload"
+-- | -- | -- | -- | --
+`load_offloaded_model` context required? | Yes | Yes | Yes | Yes
+Behavior | Try to load the model onto device 0, then broadcast replicas to other devices. Fall back to CPU and disk if the model is too large | Try to load the model onto device 0 only, then broadcast replicas to other devices. Error if the model is too large | Try to load the model onto CPU. Error if the model is too large | Try to load the model onto CPU. Fall back to disk if the model is too large
+LLM Compressor Examples | This is the recommended load option when using the "basic" pipeline |   |   | This is the recommended load option when using the "sequential" pipeline
+
+## Disk Offloading ##
+When compressing models that are larger than the available CPU memory, it is recommended to use disk offloading for any weights that cannot fit in CPU memory. To enable disk offloading, load your model with the `load_offloaded_model` context from `compressed_tensors`, along with `device_map="auto_offload"`.
+
+```python
+from compressed_tensors.offload import load_offloaded_model
+
+with load_offloaded_model():
+    model_id = "Qwen/Qwen3-0.6B"
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        dtype="auto",
+        device_map="auto_offload",  # fit as much as possible on cpu, rest goes on disk
+        max_memory={"cpu": 6e8},  # optional argument to cap how much cpu memory to use
+        offload_folder="./offload_folder",  # file system with lots of storage
+    )
+```
+
+Use the `offload_folder` argument to control where disk-offloaded weights are stored.
+
+You can then call `oneshot` as usual to perform calibration and compression. Some operations may be slower due to disk offloading.
+
+## Distributed Oneshot ##
+When performing `oneshot` with distributed computing, ensure that your model does not replicate offloaded values across ranks; otherwise, this creates excess work and memory usage. Coordinated loading between ranks is handled automatically by the `load_offloaded_model` context, as long as it is entered after `torch.distributed` has been initialized.
+
+```python
+from compressed_tensors.offload import init_dist, load_offloaded_model
+
+init_dist()
+with load_offloaded_model():
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id, dtype="auto", device_map="auto_offload"
+    )
+```
+
+## Basic Pipeline ##
+It is recommended to use the basic pipeline only when your model is small enough to fit into the available VRAM, including any auxiliary memory requirements of algorithms such as GPTQ Hessians. In these cases, the basic pipeline can provide compression runtime speedups compared to the sequential pipeline.
+
+Load the model directly onto your GPU devices and call `oneshot` with the `pipeline="basic"` argument.
+
+```python
+model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="auto")  # model is on devices
+...
+oneshot(model, ..., pipeline="basic")
+```
\ No newline at end of file
diff --git a/docs/guides/sequential_onloading.md b/docs/guides/sequential_onloading.md
new file mode 100644
index 0000000000..8be6b9d6af
--- /dev/null
+++ b/docs/guides/sequential_onloading.md
@@ -0,0 +1,42 @@
+# Sequential Onloading #
+
+## Introduction ##
+
+LLM Compressor is capable of compressing models much larger than the available VRAM. This is achieved through a technique called **sequential onloading**, whereby only a fraction of the model weights are moved to GPU memory for calibration while the rest remain offloaded to CPU or disk. During calibration, the entire dataset is offloaded to CPU, then onloaded one batch at a time to reduce peak activation memory usage.
+
+![sequential_onloading](../assets/sequential_onloading.jpg)
+
+If basic calibration/inference is represented with the following pseudocode...
+```python
+for i in range(len(activations)):
+    for layer in model.layers:
+        activations[i] = layer(activations[i])
+```
+
+...then sequential onloading is the technique by which the order of the two for loops is swapped:
+```python
+for layer in model.layers:
+    for i in range(len(activations)):
+        activations[i] = layer(activations[i])
+```
+
+## Implementation ##
+
+Before a model can be sequentially onloaded, it must first be broken up into disjoint parts that can be onloaded individually. This is achieved through the [torch.fx.Tracer](https://github.com/pytorch/pytorch/blob/main/torch/fx/README.md#tracing) module, which allows a model to be represented as a graph of operations (nodes) and data inputs (edges). Once the model has been traced into a valid graph representation, the graph is cut (partitioned) into disjoint subgraphs, each of which is onloaded individually as a layer. The implementation can be found [here](/src/llmcompressor/pipelines/sequential/helpers.py).
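As a toy sketch of the partitioning step (assumed structure only; the real implementation operates on `torch.fx` graph nodes rather than strings), the idea is to walk the traced operations in execution order and start a new subgraph whenever the owning layer changes:

```python
def partition_nodes(nodes):
    """Group an ordered list of (layer_name, op) trace entries into disjoint
    subgraphs, one per layer, preserving execution order."""
    subgraphs = []
    current_layer, current = None, []
    for layer, op in nodes:
        # Cut the graph whenever execution crosses a layer boundary
        if layer != current_layer and current:
            subgraphs.append(current)
            current = []
        current_layer = layer
        current.append(op)
    if current:
        subgraphs.append(current)
    return subgraphs

trace = [
    ("embed", "embed_tokens"),
    ("layers.0", "self_attn"), ("layers.0", "mlp"),
    ("layers.1", "self_attn"), ("layers.1", "mlp"),
    ("head", "lm_head"),
]
# Yields four disjoint subgraphs, each onloaded (and offloaded) one at a time.
print(partition_nodes(trace))
```

Because the subgraphs are disjoint and ordered, executing them one after another is equivalent to executing the original model, which is what makes sequential onloading numerically transparent.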
+
+![sequential_onloading](../assets/model_graph.jpg)
+*This image depicts some of the operations performed when executing the Llama3.2-Vision model*
+
+![sequential_onloading](../assets/sequential_decoder_layers.jpg)
+*This image depicts the sequential text decoder layers of the Llama3.2-Vision model. Each of the individual decoder layers is onloaded separately*
+
+## Sequential Targets and Usage ##
+You can use sequential onloading by calling `oneshot` with the `pipeline="sequential"` argument. Note that this pipeline is the default for all `oneshot` calls that require calibration data. If the sequential pipeline proves problematic, you can specify `pipeline="basic"` to use a basic pipeline which does not require sequential onloading, but which only performs well when the model is small enough to fit into the available VRAM.
+
+If you are compressing a model using a GPU with a small amount of memory, you may need to change your sequential targets. Sequential targets control how many weights are onloaded to the GPU at a time. By default, the sequential targets are decoder layers, which may include large MoE layers. In these cases, setting the `sequential_targets="Linear"` argument in `oneshot` results in lower VRAM usage at the cost of a longer runtime.
+
+![sequential_onloading](../assets/seq_targets.jpg)
+
+## More information ##
+
+For more information, see the [RedHat AI blog post](https://developers.redhat.com/articles/2025/05/09/llm-compressor-optimize-llms-low-latency-deployments#generalizing_to_multimodal_and_moe_architectures) or the [LLM Compressor Office Hours Recording](https://www.youtube.com/watch?v=GrhuqQDmBk8).
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index e4d6879e4a..76f196415d 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -6,7 +6,7 @@ LLM Compressor Flow

-## What challenges does LLM Compressor address?
+## Which challenges does LLM Compressor address?
 
 Model optimization through quantization and pruning addresses the key challenges of deploying AI at scale:
diff --git a/examples/disk_offloading/README.md b/examples/disk_offloading/README.md
new file mode 100644
index 0000000000..5f465f6ad5
--- /dev/null
+++ b/examples/disk_offloading/README.md
@@ -0,0 +1,2 @@
+## Disk Offloading ##
+For more information on disk offloading, see [Model Loading](/docs/guides/model_loading.md).
\ No newline at end of file