Skip to content

Commit 953ab34

Browse files
ochougul and qcdipankar committed
Add fp8 support (#802)
Signed-off-by: Dipankar Sarkar <quic_dipankar@quicinc.com>
Signed-off-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
Co-authored-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
1 parent 7d68604 commit 953ab34

File tree

10 files changed

+1770
-113
lines changed

10 files changed

+1770
-113
lines changed

QEfficient/transformers/cache_utils.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ def _get_invalid_idx_value(cls):
5555

5656

5757
class QEffDynamicLayer(DynamicLayer):
58+
def lazy_initialization(self, key_states: torch.Tensor):
59+
self.dtype, self.device = key_states.dtype, key_states.device
60+
self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
61+
self.values = torch.tensor([], dtype=self.dtype, device=self.device)
62+
self.is_initialized = True
63+
5864
def read_only(self, cache_kwargs):
5965
"""
6066
Reads the `key_states` and `value_states` for the layer.
@@ -151,6 +157,7 @@ def write_only(self, key_states, value_states, cache_kwargs):
151157
self.keys = key_states
152158
self.values = value_states
153159
else:
160+
# breakpoint()
154161
position_ids = cache_kwargs.get("position_ids")
155162
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs
156163

@@ -185,11 +192,15 @@ def update(
185192
Return:
186193
A tuple containing the updated key and value states.
187194
"""
195+
# breakpoint()
188196
# Update the cache
197+
# if not self.is_initialized:
198+
189199
if self.keys is None:
190200
self.keys = key_states
191201
self.values = value_states
192202
k_out, v_out = self.keys, self.values
203+
self.is_initialized = True
193204
else:
194205
position_ids = cache_kwargs.get("position_ids")
195206
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs
@@ -306,15 +317,48 @@ class QEffDynamicCache(DynamicCache):
306317
307318
"""
308319

309-
def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
320+
def __init__(
321+
self,
322+
ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
323+
config=None,
324+
offloading: bool = False,
325+
offload_only_non_sliding: bool = False,
326+
*args,
327+
**kwargs,
328+
):
310329
# Remove layer_classes if present to avoid duplicate argument
311-
kwargs.pop("layer_classes", None)
330+
# breakpoint()
331+
kwargs.pop("layers", None)
312332
from transformers.cache_utils import Cache # Import here to avoid circular import
313333

314-
Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
334+
# breakpoint()
335+
layers = []
336+
# If a config is passed, use it to infer the layer types and initialize accordingly
337+
if len(layers) == 0:
338+
Cache.__init__(
339+
self,
340+
layer_class_to_replicate=QEffDynamicLayer,
341+
offloading=offloading,
342+
offload_only_non_sliding=offload_only_non_sliding,
343+
# args=args,
344+
# kwargs=kwargs,
345+
)
346+
else:
347+
Cache.__init__(
348+
self,
349+
layers=layers,
350+
offloading=offloading,
351+
offload_only_non_sliding=offload_only_non_sliding,
352+
# args=args,
353+
# kwargs=kwargs,
354+
)
355+
315356
if ddp_cache_data is not None:
316-
for key_states, value_states in ddp_cache_data:
317-
self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states))
357+
for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
358+
# If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
359+
layers.append(QEffDynamicLayer())
360+
# Update the layer with the data
361+
_, _ = layers[layer_idx].update(key_states, value_states)
318362

319363
def read_only(self, layer_idx, cache_kwargs):
320364
"""
@@ -329,6 +373,7 @@ def read_only(self, layer_idx, cache_kwargs):
329373
Return:
330374
A tuple containing the updated key and value states.
331375
"""
376+
# breakpoint()
332377
return self.layers[layer_idx].read_only(cache_kwargs)
333378

334379
def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs):
@@ -394,6 +439,18 @@ def update3D(
394439
self.append_new_layers(layer_idx)
395440
return self.layers[layer_idx].update3D(key_states, value_states, cache_kwargs)
396441

442+
# def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
443+
# """Returns the sequence length of the cached states. A layer index can be optionally passed."""
444+
# # TODO: deprecate this function in favor of `cache_position`
445+
# breakpoint()
446+
# is_empty_layer = (
447+
# len(self.key_cache) == 0 # no cache in any layer
448+
# or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
449+
# or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
450+
# )
451+
# layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
452+
# return layer_seq_length
453+
397454

398455
class QEffEncoderDecoderCache(EncoderDecoderCache):
399456
"""

QEfficient/transformers/models/modeling_auto.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers
6363
from QEfficient.transformers.quantizers.quant_transforms import (
6464
AwqToMatmulNbitsTransform,
65+
FP8BlockWiseDequantLinearToLinearTransform,
66+
FP8BlockWiseDequantQwen3VLMoeTextExpertsToQwen3VLMoeTextExpertsTransform,
6567
FP8DeQuantLinearToLinearTransform,
6668
GPTQToMatmulNbitsTransform,
6769
Mxfp4GptOssExpertDequantizeTransform,
@@ -964,6 +966,8 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
964966
_pytorch_transforms = [
965967
AwqToMatmulNbitsTransform,
966968
GPTQToMatmulNbitsTransform,
969+
FP8BlockWiseDequantQwen3VLMoeTextExpertsToQwen3VLMoeTextExpertsTransform,
970+
FP8BlockWiseDequantLinearToLinearTransform,
967971
CustomOpsTransform,
968972
KVCacheTransform,
969973
VlmKVOffloadTransform,
@@ -1618,6 +1622,7 @@ def kv_offload_generate(
16181622
AssertionError
16191623
If `generation_len` is not greater than zero.
16201624
"""
1625+
# breakpoint()
16211626
if not self.lang_model.qpc_path:
16221627
raise TypeError("Please run compile API for language model first!")
16231628

@@ -1649,7 +1654,7 @@ def kv_offload_generate(
16491654
[x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes]
16501655
+ [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]]
16511656
)
1652-
1657+
# breakpoint()
16531658
input_len = inputs["attention_mask"].sum(1, keepdims=True)
16541659
input_ids_length = inputs["input_ids"].shape[1]
16551660
num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float
@@ -1695,7 +1700,7 @@ def kv_offload_generate(
16951700
vision_end = perf_counter()
16961701

16971702
lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
1698-
1703+
# breakpoint()
16991704
if "position_ids" in inputs:
17001705
lang_inputs["position_ids"] = inputs["position_ids"]
17011706
lang_inputs.pop("attention_mask")
@@ -1707,7 +1712,7 @@ def kv_offload_generate(
17071712
not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama"
17081713
if not_mllama:
17091714
lang_inputs["image_idx"] = np.array([[0]])
1710-
1715+
# breakpoint()
17111716
if self.vision_model.qpc_path:
17121717
vision_session.deactivate()
17131718
lang_session.activate()
@@ -1722,7 +1727,7 @@ def kv_offload_generate(
17221727
lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
17231728

17241729
lang_start = perf_counter()
1725-
1730+
# breakpoint()
17261731
# Run prefill
17271732
chunk_inputs = lang_inputs.copy()
17281733
for i in range(num_chunks):
@@ -1751,7 +1756,7 @@ def kv_offload_generate(
17511756
)
17521757
if not_mllama:
17531758
lang_session.skip_buffers(vision_outputs.keys())
1754-
1759+
# breakpoint()
17551760
# Get first token
17561761
lang_inputs["input_ids"] = outputs["logits"].argmax(2)
17571762
lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1

0 commit comments

Comments (0)