-
Notifications
You must be signed in to change notification settings - Fork 200
[WIP] [MoE] GPT OSS #1705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
kylesayrs
wants to merge
3
commits into
main
Choose a base branch
from
kylesayrs/gpt_oss
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+406
−6
Draft
[WIP] [MoE] GPT OSS #1705
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from datasets import load_dataset | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
from llmcompressor import oneshot | ||
from llmcompressor.modifiers.quantization import GPTQModifier | ||
from llmcompressor.utils import dispatch_for_generation | ||
|
||
# Select model and load it. | ||
model_id = "unsloth/gpt-oss-20b-BF16" | ||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") | ||
tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
|
||
# Select calibration dataset. | ||
DATASET_ID = "HuggingFaceH4/ultrachat_200k" | ||
DATASET_SPLIT = "train_sft" | ||
|
||
# Select number of samples. 512 samples is a good place to start. | ||
# Increasing the number of samples can improve accuracy. | ||
NUM_CALIBRATION_SAMPLES = 512 | ||
MAX_SEQUENCE_LENGTH = 2048 | ||
|
||
# Load dataset and preprocess. | ||
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") | ||
ds = ds.shuffle(seed=42) | ||
|
||
|
||
def preprocess(example): | ||
return { | ||
"text": tokenizer.apply_chat_template( | ||
example["messages"], | ||
tokenize=False, | ||
) | ||
} | ||
|
||
|
||
ds = ds.map(preprocess) | ||
|
||
|
||
# Tokenize inputs. | ||
def tokenize(sample): | ||
return tokenizer( | ||
sample["text"], | ||
padding=False, | ||
max_length=MAX_SEQUENCE_LENGTH, | ||
truncation=True, | ||
add_special_tokens=False, | ||
) | ||
|
||
|
||
ds = ds.map(tokenize, remove_columns=ds.column_names) | ||
|
||
# Configure the quantization algorithm to run. | ||
# * quantize the weights to 4 bit with GPTQ with a group size 128 | ||
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) | ||
|
||
# Apply algorithms. | ||
oneshot( | ||
model=model, | ||
dataset=ds, | ||
recipe=recipe, | ||
max_seq_length=MAX_SEQUENCE_LENGTH, | ||
num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
calibrate_moe_context=True, | ||
pipeline="sequential", | ||
) | ||
|
||
# # Confirm generations of the quantized model look sane. | ||
print("\n\n") | ||
print("========== SAMPLE GENERATION ==============") | ||
dispatch_for_generation(model) | ||
sample = tokenizer("Hello my name is", return_tensors="pt") | ||
sample = {key: value.to("cuda") for key, value in sample.items()} | ||
output = model.generate(**sample, max_new_tokens=100) | ||
print(tokenizer.decode(output[0])) | ||
print("==========================================\n\n") | ||
|
||
# Save to disk compressed. | ||
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16" | ||
model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
tokenizer.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
# flake8: noqa | ||
from typing import List | ||
|
||
import torch | ||
from compressed_tensors.utils import align_module_device, update_offload_parameter | ||
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig | ||
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts | ||
|
||
from llmcompressor.utils.dev import skip_weights_initialize | ||
|
||
|
||
class GptOssExpert(torch.nn.Module): | ||
gate_proj: torch.nn.Linear | ||
up_proj: torch.nn.Linear | ||
down_proj: torch.nn.Linear | ||
|
||
def __init__(self, experts: GptOssExperts): | ||
super().__init__() | ||
|
||
self.hidden_size = experts.hidden_size | ||
self.expert_dim = experts.expert_dim | ||
self.alpha = experts.alpha | ||
self.limit = experts.limit | ||
|
||
assert experts.gate_up_proj.dtype == experts.gate_up_proj_bias.dtype | ||
assert experts.down_proj.dtype == experts.down_proj_bias.dtype | ||
|
||
with skip_weights_initialize(): | ||
self.gate_proj = torch.nn.Linear( | ||
self.hidden_size, | ||
self.expert_dim, | ||
bias=True, | ||
dtype=experts.gate_up_proj.dtype, | ||
) | ||
self.up_proj = torch.nn.Linear( | ||
self.hidden_size, | ||
self.expert_dim, | ||
bias=True, | ||
dtype=experts.gate_up_proj.dtype, | ||
) | ||
self.down_proj = torch.nn.Linear( | ||
self.expert_dim, | ||
self.hidden_size, | ||
bias=True, | ||
dtype=experts.down_proj.dtype, | ||
) | ||
|
||
def forward(self, hidden_states: torch.Tensor): | ||
gate = self.gate_proj(hidden_states) | ||
gate = gate.clamp(min=None, max=self.limit) | ||
|
||
up = self.up_proj(hidden_states) | ||
up = up.clamp(min=-self.limit, max=self.limit) | ||
|
||
glu = gate * torch.sigmoid(gate * self.alpha) | ||
return self.down_proj((up + 1) * glu) | ||
|
||
|
||
class GptOssExpertsLinear(torch.nn.Module): | ||
experts: List[GptOssExpert] | ||
|
||
def __init__(self, experts: GptOssExperts): | ||
super().__init__() | ||
|
||
self.intermediate_size = experts.intermediate_size | ||
self.num_experts = experts.num_experts | ||
self.hidden_size = experts.hidden_size | ||
self.expert_dim = experts.expert_dim | ||
|
||
with skip_weights_initialize(): | ||
self.experts = torch.nn.ModuleList( | ||
[GptOssExpert(experts) for _ in range(self.num_experts)] | ||
) | ||
|
||
self.load_weights(experts) | ||
|
||
self.alpha = experts.alpha | ||
self.limit = experts.limit | ||
|
||
def load_weights(self, experts: GptOssExperts): | ||
with align_module_device(experts): | ||
for expert_index, expert in enumerate(self.experts): | ||
update_offload_parameter( | ||
expert.gate_proj, | ||
"weight", | ||
experts.gate_up_proj[expert_index, ..., ::2].T, | ||
) | ||
update_offload_parameter( | ||
expert.gate_proj, | ||
"bias", | ||
experts.gate_up_proj_bias[expert_index, ..., ::2], | ||
) | ||
|
||
update_offload_parameter( | ||
expert.up_proj, | ||
"weight", | ||
experts.gate_up_proj[expert_index, ..., 1::2].T, | ||
) | ||
update_offload_parameter( | ||
expert.up_proj, | ||
"bias", | ||
experts.gate_up_proj_bias[expert_index, ..., 1::2], | ||
) | ||
|
||
update_offload_parameter( | ||
expert.down_proj, "weight", experts.down_proj[expert_index].T | ||
) | ||
update_offload_parameter( | ||
expert.down_proj, "bias", experts.down_proj_bias[expert_index] | ||
) | ||
|
||
def to_original(self) -> GptOssExperts: | ||
# TODO: this doesn't really handle offloading or correct device placement | ||
with skip_weights_initialize(use_zeros=True): | ||
fake_config = GptOssConfig( | ||
intermediate_size=self.intermediate_size, | ||
num_local_experts=self.num_experts, | ||
hidden_size=self.hidden_size, | ||
) | ||
experts = GptOssExperts(fake_config) | ||
experts.gate_up_proj = torch.nn.Parameter( | ||
experts.gate_up_proj.to(dtype=self.experts[0].gate_proj.weight.dtype), | ||
requires_grad=False, | ||
) | ||
experts.gate_up_proj_bias = torch.nn.Parameter( | ||
experts.gate_up_proj_bias.to( | ||
dtype=self.experts[0].gate_proj.weight.dtype | ||
), | ||
requires_grad=False, | ||
) | ||
experts.down_proj = torch.nn.Parameter( | ||
experts.down_proj.to(dtype=self.experts[0].down_proj.weight.dtype), | ||
requires_grad=False, | ||
) | ||
experts.down_proj_bias = torch.nn.Parameter( | ||
experts.down_proj_bias.to(dtype=self.experts[0].down_proj.weight.dtype), | ||
requires_grad=False, | ||
) | ||
|
||
for expert_index, expert in enumerate(self.experts): | ||
with align_module_device(expert.gate_proj, "cpu"), align_module_device( | ||
expert.up_proj, "cpu" | ||
), align_module_device(expert.down_proj, "cpu"): | ||
experts.gate_up_proj[expert_index, ..., ::2].copy_( | ||
expert.gate_proj.weight.data.T | ||
) | ||
experts.gate_up_proj_bias[expert_index, ..., ::2].copy_( | ||
expert.gate_proj.bias.data | ||
) | ||
|
||
experts.gate_up_proj[expert_index, ..., 1::2].copy_( | ||
expert.up_proj.weight.data.T | ||
) | ||
experts.gate_up_proj_bias[expert_index, ..., 1::2].copy_( | ||
expert.up_proj.bias.data | ||
) | ||
|
||
experts.down_proj[expert_index].copy_(expert.down_proj.weight.data.T) | ||
experts.down_proj_bias[expert_index].copy_(expert.down_proj.bias.data) | ||
|
||
# TODO: convert qparams as well | ||
|
||
print("converted, for some reason slows down over time") | ||
import time | ||
|
||
print(time.time()) | ||
|
||
experts.eval() | ||
return experts | ||
|
||
def forward( | ||
self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None | ||
) -> torch.Tensor: | ||
""" | ||
When training is is more efficient to just loop over the experts and compute the output for each expert | ||
as otherwise the memory would explode. | ||
|
||
For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. | ||
|
||
Args: | ||
hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) | ||
selected_experts (torch.Tensor): (batch_size * token_num, top_k) | ||
routing_weights (torch.Tensor): (batch_size * token_num, num_experts) | ||
Returns: | ||
torch.Tensor | ||
""" | ||
original_shape = hidden_states.shape | ||
hidden_states = hidden_states.reshape( | ||
-1, self.hidden_size | ||
) # (num_tokens, hidden_size) | ||
|
||
next_states = torch.zeros_like( | ||
hidden_states, dtype=hidden_states.dtype, device=hidden_states.device | ||
) | ||
for expert_index, expert in enumerate(self.experts): | ||
next_states += expert(hidden_states) * routing_weights.T[ | ||
expert_index | ||
].unsqueeze(-1) | ||
|
||
next_states = next_states.reshape(original_shape) | ||
return next_states | ||
|
||
|
||
def replace_gpt_oss(config: GptOssConfig, module: GptOssExpert): | ||
return GptOssExpertsLinear(module) | ||
|
||
|
||
def test_restore(): | ||
config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5) | ||
|
||
original = GptOssExperts(config) | ||
linear = GptOssExpertsLinear(original) | ||
|
||
restored = linear.to_original() | ||
for param_name, param in original.named_parameters(recurse=False): | ||
restored_param = getattr(restored, param_name) | ||
assert param.shape == restored_param.shape | ||
assert param.dtype == restored_param.dtype | ||
|
||
assert torch.all(getattr(restored, param_name) == param) | ||
|
||
|
||
def test_correctness(): | ||
batch_size, seq_len = 13, 12 | ||
config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5) | ||
|
||
input = torch.rand((batch_size, seq_len, config.hidden_size)) | ||
routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts)) | ||
|
||
with torch.no_grad(): | ||
original = GptOssExperts(config) | ||
for name in [ | ||
"gate_up_proj", | ||
"gate_up_proj_bias", | ||
"down_proj", | ||
"down_proj_bias", | ||
]: | ||
setattr(original, name, getattr(original, name).normal_()) | ||
|
||
original.eval() | ||
assert not original.training | ||
true_output = original(input, routing_weights=routing_weights) | ||
|
||
linear = GptOssExpertsLinear(original) | ||
output = linear(input, routing_weights=routing_weights) | ||
|
||
assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0) | ||
|
||
restored = linear.to_original() | ||
restored_output = restored(input, routing_weights=routing_weights) | ||
assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_restore() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.