-
Notifications
You must be signed in to change notification settings - Fork 453
[Offloading] Support Disk Offloading #2373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
2db91e2
changes
kylesayrs 87c79d8
add examples
kylesayrs 43bdec4
clean up example
kylesayrs 5133d63
explicit offload folders
kylesayrs 550b9d5
Merge branch 'main' into kylesayrs/support-disk-offloading
kylesayrs f3a75d7
Merge branch 'main' into kylesayrs/support-disk-offloading
kylesayrs e489239
fix test
kylesayrs f53095b
revert change
kylesayrs 0005de0
fix example
kylesayrs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| from compressed_tensors.offload import get_device_map, load_offloaded_model | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| # Select model and load it in the `load_offloaded_model` context | ||
| with load_offloaded_model(): | ||
| model_id = "unsloth/Kimi-K2-Instruct-0905-BF16" | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_id, | ||
| dtype="auto", | ||
| device_map="auto_offload", # fit as much as possible on cpu, rest goes on disk | ||
| trust_remote_code=True, | ||
| offload_folder="./offload_folder", | ||
| ) | ||
| tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) | ||
|
|
||
| # Confirm that model is dispatched correctly | ||
| devices = {offloaded for _onloaded, offloaded in get_device_map(model).values()} | ||
| print(f"Model was offloaded to the following devices: {devices}") | ||
|
|
||
| # Select calibration dataset. | ||
| DATASET_ID = "ultrachat-200k" | ||
| DATASET_SPLIT = "train_sft" | ||
|
|
||
| # Select number of samples. 512 samples is a good place to start. | ||
| # Increasing the number of samples can improve accuracy. | ||
| NUM_CALIBRATION_SAMPLES = 20 | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
|
|
||
| # Configure the quantization algorithm to run. | ||
| # * quantize the weights to NVFP4 | ||
| recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"]) | ||
|
|
||
| # Apply algorithms. | ||
| oneshot( | ||
| model=model, | ||
| processor=tokenizer, | ||
| dataset=DATASET_ID, | ||
| splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"}, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| ) | ||
|
|
||
| # Save to disk compressed. | ||
| SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-NVFP4" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| from compressed_tensors.offload import ( | ||
| dispatch_model, | ||
| get_device_map, | ||
| load_offloaded_model, | ||
| ) | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
|
|
||
| # Select model and load it in the `load_offloaded_model` context | ||
| # In this example, we emulate large model quantization with disk offloading by | ||
| # restricting the theoretical size of CPU RAM to be smaller than the size of the model | ||
| with load_offloaded_model(): | ||
| model_id = "Qwen/Qwen3-0.6B" | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| model_id, | ||
| dtype="auto", | ||
| device_map="auto_offload", # fit as much as possible on cpu, rest goes on disk | ||
| max_memory={"cpu": 6e8}, # remove this line to use as much cpu as possible | ||
| offload_folder="./offload_folder", | ||
| ) | ||
| tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
|
|
||
| # Confirm that model is dispatched correctly | ||
| devices = {offloaded for _onloaded, offloaded in get_device_map(model).values()} | ||
| print(f"Model was offloaded to the following devices: {devices}") | ||
|
|
||
| # Select calibration dataset. | ||
| DATASET_ID = "ultrachat-200k" | ||
| DATASET_SPLIT = "train_sft" | ||
|
|
||
| # Select number of samples. 512 samples is a good place to start. | ||
| # Increasing the number of samples can improve accuracy. | ||
| NUM_CALIBRATION_SAMPLES = 20 | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
|
|
||
| # Configure the quantization algorithm to run. | ||
| # * quantize the weights to NVFP4 | ||
| recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"]) | ||
|
|
||
| # Apply algorithms. | ||
| oneshot( | ||
| model=model, | ||
| dataset=DATASET_ID, | ||
| splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"}, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| ) | ||
|
|
||
| # Confirm generations of the quantized model look sane. | ||
| print("\n\n") | ||
| print("========== SAMPLE GENERATION ==============") | ||
| dispatch_model(model) | ||
| sample = tokenizer("Hello my name is", return_tensors="pt") | ||
| sample = {key: value.to(model.device) for key, value in sample.items()} | ||
| output = model.generate(**sample, max_new_tokens=100) | ||
| print(tokenizer.decode(output[0])) | ||
| print("==========================================\n\n") | ||
|
|
||
| # Save to disk compressed. | ||
| SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-NVFP4" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.