Skip to content
Open

MLA #789

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,23 @@ def _export(
raise ValueError(
f"Unknown shape of past_key_values! Expected length of past_key_values for each layer to be either 2 or 4 but got {len(example_inputs['past_key_values'][0])}"
)
elif param == "compressed_kvs":
for i in range(len(example_inputs["compressed_kvs"])):
# input_names.extend([f"compressed_kvs.{i}",])
input_names.extend(
[
f"compressed_kv.{i}",
]
)
input_names.extend(
[
f"k_pe.{i}",
]
)
else:
input_names.append(param)

# import ipdb; ipdb.set_trace()
try:
torch.onnx.export(
self.model,
Expand Down Expand Up @@ -329,11 +343,15 @@ def get_onnx_path(
offload_pt_weights: Optional[bool] = True,
use_onnx_subfunctions: Optional[bool] = False,
retain_full_kv: Optional[bool] = False,
enable_mla: Optional[bool] = False,
mla_absorption_config: Optional[bool] = False,
):
kwargs = {
"offload_pt_weights": offload_pt_weights,
"use_onnx_subfunctions": use_onnx_subfunctions,
"retain_full_kv": retain_full_kv,
"enable_mla": enable_mla,
"mla_absorption_config": mla_absorption_config,
}

if prefill_only:
Expand Down Expand Up @@ -366,6 +384,8 @@ def _compile(
offload_pt_weights: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
retain_full_kv: Optional[bool] = None,
enable_mla: Optional[bool] = False,
mla_absorption_config: Optional[Dict[str, bool]] = False,
**compiler_options,
) -> str:
"""
Expand Down Expand Up @@ -403,6 +423,8 @@ def _compile(
offload_pt_weights,
use_onnx_subfunctions,
retain_full_kv,
enable_mla,
mla_absorption_config,
)
)
compile_dir = Path(compile_dir or onnx_path.parent)
Expand Down
12 changes: 10 additions & 2 deletions QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# ----------------------------------------------------------------------------
from types import MethodType
from typing import Callable, Dict, Tuple, Type
from typing import Callable, Dict, Optional, Tuple, Type

from torch import nn

Expand Down Expand Up @@ -70,6 +70,7 @@ class ModuleMutatorTransform(PytorchTransform):
"""

_match_class: nn.Module
_match_string: Optional[str] = None

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
Expand Down Expand Up @@ -108,7 +109,14 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
repl_method_map := cls._match_string_replace_method.get(module.__class__.__name__)
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))
parts = orig_method_name.split(".")
if len(parts) > 1:
target = module
for part in parts[:-1]:
target = getattr(target, part)
setattr(target, parts[-1], MethodType(mapped_method, target))
else:
setattr(module, orig_method_name, MethodType(mapped_method, module))

if hasattr(module, "__qeff_init__"):
module.__qeff_init__()
Expand Down
61 changes: 55 additions & 6 deletions QEfficient/customop/matmulnbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def symbolic(g, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group
@staticmethod
def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, group_size, in_features, out_features):
if torch.onnx.is_in_onnx_export():
# For faster export
return torch.zeros(x.shape[:-1] + (out_features,), dtype=x.dtype).float()
fp_weight = dequantize_blockwise_bits(
qself_qweight, qself_scales, qself_qzeros, bits, group_size, g_idx, in_features, out_features
Expand All @@ -40,8 +41,7 @@ def forward(ctx, x, qself_qweight, qself_scales, qself_qzeros, g_idx, bits, grou
def dequantize_blockwise_bits(quant_values, scale, zero_point, bits, group_size, g_idx, rows, cols):
if bits != 4:
raise ValueError("Only bits=4 is supported for executing quantized model")
if group_size != 128:
raise ValueError("Only group_size=128 is supported for executing quantized model")

expand_quant_value = (quant_values.unsqueeze(-1) >> torch.tensor([[[[0, 4]]]], dtype=torch.int32)) & 0x0F
expand_quant_value = expand_quant_value.reshape(*quant_values.shape[:-1], -1)
aligned_scale = scale.reshape(*quant_values.shape[:-1], 1)
Expand Down Expand Up @@ -88,20 +88,20 @@ def __init__(self, bits, group_size, in_features, out_features, bias):
q_rows = in_features // self.group_size
self.register_buffer(
"qweight",
torch.zeros((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8),
torch.empty((out_features, q_rows, self.group_size // (8 // bits)), dtype=torch.uint8),
)
self.register_buffer(
"qzeros",
torch.zeros((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
torch.empty((q_rows + (q_rows & 1)) * (out_features // 8 * self.bits), dtype=torch.uint8),
)
self.register_buffer(
"scales", torch.zeros((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16)
"scales", torch.empty((math.ceil(in_features / self.group_size) * out_features), dtype=torch.float16)
)
self.register_buffer(
"g_idx", torch.tensor([i // self.group_size for i in range(in_features)], dtype=torch.int32)
)
if bias:
self.register_buffer("bias", torch.zeros((out_features), dtype=torch.float16))
self.register_buffer("bias", torch.empty((out_features), dtype=torch.float16))
else:
self.bias = None

Expand Down Expand Up @@ -180,3 +180,52 @@ def forward(self, inputs):
)
out = out + self.bias if self.bias is not None else out
return out


class QMOE(torch.autograd.Function):
    """ONNX custom-op wrapper for ONNX Runtime's ``com.microsoft::QMOE``
    (quantized Mixture-of-Experts) contrib operator.

    Only ONNX export is supported: during ``torch.onnx.export`` the node is
    emitted via :meth:`symbolic` and :meth:`forward` returns a shape/dtype
    placeholder. An eager-mode PyTorch reference (gather top-k experts,
    dequantize, gate/up/down matmuls) is still TODO, so calling the op
    outside export raises ``NotImplementedError``.
    """

    @staticmethod
    def symbolic(
        g,
        x,
        router_probs,
        fc1_experts_weights,
        fc1_scales,
        fc2_experts_weights,
        fc2_scales,
        fc3_experts_weights,
        fc3_scales,
        activation_type,
        block_size,
        expert_weight_bits,
        k,
    ):
        # Emit the ONNX Runtime contrib op; the ``_i`` suffix marks
        # integer attributes per torch.onnx symbolic conventions.
        return g.op(
            "com.microsoft::QMOE",
            input=x,
            router_probs=router_probs,
            fc1_experts_weights=fc1_experts_weights,
            fc1_scales=fc1_scales,
            fc2_experts_weights=fc2_experts_weights,
            fc2_scales=fc2_scales,
            fc3_experts_weights=fc3_experts_weights,
            fc3_scales=fc3_scales,
            outputs=1,
            activation_type_i=activation_type,
            block_size_i=block_size,
            expert_weight_bits_i=expert_weight_bits,
            k_i=k,
        )

    @staticmethod
    def forward(
        ctx,
        x,
        router_probs,
        fc1_experts_weights,
        fc1_scales,
        fc2_experts_weights,
        fc2_scales,
        fc3_experts_weights,
        fc3_scales,
        activation_type,
        block_size,
        expert_weight_bits,
        k,
    ):
        if torch.onnx.is_in_onnx_export():
            # During export the real computation is done by the QMOE kernel;
            # a correctly-shaped placeholder keeps the trace fast.
            return torch.zeros_like(x)
        # TODO: eager-mode reference — gather the top-k experts' quantized
        # weights, dequantize blockwise, and run gate/up/down matmuls.
        # The previous draft here referenced undefined names (``self``,
        # ``scores``, ``hidden_states``, ``gate_proj`` ...) in a staticmethod
        # and could never execute; fail loudly instead of with a NameError.
        raise NotImplementedError(
            "QMOE eager execution is not implemented; QMOE is only usable during ONNX export."
        )
7 changes: 7 additions & 0 deletions QEfficient/generation/text_generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,12 @@ def __init__(
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("past_")]
)
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("compressed_")]
)
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("k_pe")]
)

def _set_tokenizer_params(self):
"""
Expand Down Expand Up @@ -840,6 +846,7 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
]
if self.include_sampler:
chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]

outputs = self._session.run(chunk_inputs)

if self._write_io_dir is not None:
Expand Down
82 changes: 82 additions & 0 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,88 @@ def update3D(
return k_out, v_out


class QEffDynamicCompressedKVRopeLayer:
    """Single-layer cache for MLA-style attention KV state.

    Holds two tensors for one decoder layer:

    - ``ckv``:  the compressed key/value latent cache (3D; context axis is dim 1,
      per the ``CtxScatterFunc3D``/``CtxGatherFunc3D`` ops used on it)
    - ``k_pe``: the rotary positional-embedding key cache (context axis is dim -2)

    Both ``update_*`` methods scatter the current step's values into the cache
    at ``position_ids``, then gather the whole context back out with positions
    beyond the newest written position masked to zero, so stale cache contents
    never leak into attention.
    """

    def __init__(self, ckv, k_pe):
        # Pre-allocated cache tensors; replaced in place on every update step.
        self.ckv = ckv
        self.k_pe = k_pe

    def update_ckv(self, compressed_kv, cache_kwargs):
        """Scatter *compressed_kv* into the cache and return the gathered,
        validity-masked compressed-KV context.

        ``cache_kwargs`` must provide ``position_ids``.
        """
        position_ids = cache_kwargs.get("position_ids")
        batch_index = cache_kwargs.get("batch_index", None)  # TODO: add support later (read but unused)

        # Write the new entries at their positions in the 3D cache.
        self.ckv = CtxScatterFunc3D.apply(self.ckv, position_ids, compressed_kv)

        ckv_out = self.ckv
        ctx_len = ckv_out.shape[1]
        # One gather index per context slot; slots past the per-batch maximum
        # position are invalid and must not expose stale cache contents.
        ctx_indices = torch.arange(ctx_len)[None, ...]
        gather_limit = position_ids.max(1, keepdim=True).values
        invalid_mask = ctx_indices > gather_limit
        if torch.onnx.is_in_onnx_export():
            # Sentinel index used in the exported graph; presumably clamped or
            # special-cased inside the CtxGather custom op — TODO confirm.
            invalid_idx_value = torch.iinfo(torch.int32).max
        else:
            invalid_idx_value = 0
        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

        ckv_out = CtxGatherFunc3D.apply(ckv_out, ctx_indices)
        # Zero out whatever was gathered for the invalid slots.
        ckv_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), ckv_out)
        return ckv_out

    def update_k_pe(self, k_pe_cache, cache_kwargs):
        """Scatter *k_pe_cache* into the rotary-key cache and return the
        gathered, validity-masked context (same masking scheme as
        :meth:`update_ckv`, but on the ``-2`` context axis).
        """
        position_ids = cache_kwargs.get("position_ids")
        batch_index = cache_kwargs.get("batch_index", None)  # TODO: add support later (read but unused)

        self.k_pe = CtxScatterFunc.apply(self.k_pe, position_ids, k_pe_cache)

        k_pe_out = self.k_pe
        ctx_len = k_pe_out.shape[-2]
        ctx_indices = torch.arange(ctx_len)[None, ...]
        gather_limit = position_ids.max(1, keepdim=True).values
        invalid_mask = ctx_indices > gather_limit
        if torch.onnx.is_in_onnx_export():
            # Same export-time sentinel trick as update_ckv.
            invalid_idx_value = torch.iinfo(torch.int32).max
        else:
            invalid_idx_value = 0
        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

        # Note: unlike the 3D variant, CtxGatherFunc also takes ctx_len.
        k_pe_out = CtxGatherFunc.apply(k_pe_out, ctx_indices, ctx_len)
        k_pe_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), k_pe_out)
        return k_pe_out


class QEffDynamicCompressedKVRopeCache:
    """Multi-layer container for MLA-style caches.

    Each entry in ``layers`` is a :class:`QEffDynamicCompressedKVRopeLayer`
    holding one decoder layer's compressed-KV and rotary (k_pe) tensors; all
    update calls are delegated to the addressed layer.
    """

    def __init__(self):
        # Layer caches, indexed by decoder layer.
        self.layers = []

    def add_new(self, ckv, k_pe, layer_idx):
        """Append a new layer cache holding *ckv* and *k_pe*."""
        self.layers.append(QEffDynamicCompressedKVRopeLayer(ckv, k_pe))

    def update_ckv(self, ckv, layer_idx, cache_kwargs):
        """Delegate the compressed-KV update to layer *layer_idx*."""
        layer = self.layers[layer_idx]
        return layer.update_ckv(ckv, cache_kwargs)

    def update_k_pe(self, k_pe, layer_idx, cache_kwargs):
        """Delegate the rotary-key update to layer *layer_idx*."""
        layer = self.layers[layer_idx]
        return layer.update_k_pe(k_pe, cache_kwargs)

    @classmethod
    def from_legacy_cache(cls, past_key_values):
        """Build a cache from a legacy sequence of ``(ckv, k_pe)`` pairs;
        ``None`` yields an empty cache."""
        cache = cls()
        if past_key_values is not None:
            for idx, (ckv, k_pe) in enumerate(past_key_values):
                cache.add_new(ckv, k_pe, idx)
        return cache

    def to_legacy_cache(self):
        """Export the cache back to the legacy tuple-of-``(ckv, k_pe)`` format."""
        return tuple((layer.ckv, layer.k_pe) for layer in self.layers)


class QEffDynamicCache(DynamicCache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
Expand Down
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/deepseek_v3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

Loading