Skip to content

Commit f5c2150

Browse files
committed
R1 working
Signed-off-by: Kyle Sayers <[email protected]>
1 parent bd7f4d5 commit f5c2150

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# flake8: noqa
22

33
from .prepare import *
4+
from .fuse import *

src/llmcompressor/modeling/fuse.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import Iterable
22

33
import torch
4-
from compressed_tensors import update_offload_parameter
4+
from compressed_tensors import get_execution_device, align_module_device, update_offload_parameter
5+
6+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
57

68
__all__ = ["fuse_norm_linears"]
79

@@ -16,10 +18,13 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear])
1618
:param norm: norm layer whose weight will be fused into subsequent linears
1719
:param linears: linear layers which directly follow the norm layer
1820
"""
19-
if isinstance(norm, torch.nn.RMSNorm):
21+
if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)):
2022
for linear in linears:
21-
# spinquant does this op in float64
22-
new_weight = linear.weight * norm.weight
23+
# NOTE: spinquant does this op in float64
24+
exec_device = get_execution_device(norm)
25+
with align_module_device(norm, exec_device), align_module_device(linear, exec_device):
26+
new_weight = linear.weight * norm.weight
27+
2328
update_offload_parameter(linear, "weight", new_weight)
2429

2530
update_offload_parameter(norm, "weight", torch.ones_like(norm.weight))

src/llmcompressor/modifiers/transform/transform.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pydantic import ValidationError, model_validator
55

66
from llmcompressor.core import Event, EventType, State
7+
from llmcompressor.modeling import fuse_norm_linears
78
from llmcompressor.modifiers import Modifier
89
from llmcompressor.modifiers.transform.presets import TRANSFORM_PRESETS
910

@@ -29,13 +30,19 @@ def validate_model_after(model: "TransformModifier") -> "TransformModifier":
2930
return model
3031

3132
def on_initialize(self, state: State, **kwargs) -> bool:
32-
apply_transform_config(state.model, self.config)
33-
3433
return True
3534

3635
def on_start(self, state: State, event: Event, **kwargs):
3736
self.started_ = True
3837

38+
for layer in state.model.model.layers:
39+
fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj))
40+
fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj))
41+
42+
# needs to happen after the model has been hooked to execute on the GPU
43+
# otherwise we're applying weight transforms on CPU
44+
apply_transform_config(state.model, self.config)
45+
3946
def on_event(self, state: State, event: Event, **kwargs):
4047
if event.type_ == EventType.CALIBRATION_EPOCH_START:
4148
if not self.started_:

src/llmcompressor/pipelines/data_free/pipeline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from llmcompressor.core.session_functions import LifecycleCallbacks
77
from llmcompressor.pipelines.registry import CalibrationPipeline
8+
from llmcompressor.utils.dev import dispatch_for_generation
89

910
if TYPE_CHECKING:
1011
from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -27,5 +28,9 @@ def __call__(
2728
:param dataloader: loads data for calibration
2829
:param dataset_args: dataset arguments relevant to pipelines
2930
"""
31+
# some ops are still performed on the model by modifiers
32+
# we want those ops to occur on the GPU
33+
dispatch_for_generation(model)
34+
3035
LifecycleCallbacks.calibration_epoch_start()
3136
LifecycleCallbacks.calibration_epoch_end()

0 commit comments

Comments
 (0)