|
| 1 | +# --------------------------------------------------------------------- |
| 2 | +# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries. |
| 3 | +# SPDX-License-Identifier: BSD-3-Clause |
| 4 | +# --------------------------------------------------------------------- |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import contextlib |
| 9 | +import json |
| 10 | +import os |
| 11 | +import shutil |
| 12 | +import struct |
| 13 | +from pathlib import Path |
| 14 | +from typing import Any, cast |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +import qai_hub as hub |
| 18 | +import torch |
| 19 | + |
| 20 | +from qai_hub_models.models._shared.llm.model import LLM_AIMETOnnx, LLMBase |
| 21 | +from qai_hub_models.utils.base_model import Precision |
| 22 | + |
| 23 | +with contextlib.suppress(ImportError): |
| 24 | + from transformers import PretrainedConfig |
| 25 | + |
| 26 | +GENIE_CONFIG_JSON = "genie_config.json" |
| 27 | + |
| 28 | + |
def _quantize_kv_cache(f: Any, encoding: Any, bw: int = 8) -> Any:
    """Quantize float KV-cache data to an unsigned integer numpy array.

    Parameters
    ----------
    f
        Float data to quantize; anything ``np.array`` accepts.
    encoding
        Quantization encoding in one of two shapes:
        a list of dicts (keys ``scale``/``offset``/``bitwidth``) or a dict
        of lists (keys ``scale``/``offset``/``bw``). Only the first entry
        of either shape is used.
    bw
        Expected bitwidth; must equal the bitwidth stored in ``encoding``.

    Returns
    -------
    ``np.ndarray`` with dtype ``uint8``/``uint16``/``uint32``/``uint64``
    matching ``bw``, values clipped to the dtype's range.

    Raises
    ------
    TypeError
        If ``encoding`` is neither a list nor a dict.
    ValueError
        If ``bw`` is not one of the supported bitwidths.
    """
    # Extract a single (scale, offset) pair from whichever encoding shape
    # we were handed, checking the recorded bitwidth matches.
    if isinstance(encoding, list):
        scale, offset = encoding[0]["scale"], encoding[0]["offset"]
        assert encoding[0]["bitwidth"] == bw
    elif isinstance(encoding, dict):
        scale, offset = encoding["scale"][0], encoding["offset"][0]
        assert encoding["bw"] == bw
    else:
        raise TypeError(f"Unknown encoding format: {type(encoding)}")

    data = np.array(f)
    bw_to_dtype: dict[int, np.dtype[Any]] = {
        8: np.dtype(np.uint8),
        16: np.dtype(np.uint16),
        32: np.dtype(np.uint32),
        64: np.dtype(np.uint64),
    }
    if bw not in bw_to_dtype:
        raise ValueError(
            f"Unsupported bitwidth: {bw}. Supported: {list(bw_to_dtype.keys())}"
        )
    out_dtype = bw_to_dtype[bw]

    # Quantize with half-away-from-zero rounding (NOT numpy's default
    # banker's rounding), then clamp into the target dtype's range.
    shifted = data / scale - offset
    sign = np.where(shifted < 0, -1, 1).astype(np.float32)
    rounded = np.floor(np.abs(shifted) + 0.5) * sign
    info = np.iinfo(out_dtype)
    return rounded.clip(info.min, info.max).astype(out_dtype)
| 60 | + |
| 61 | + |
def _save_kv_cache(
    kvcache: Any, encodings: Any, filename: str, num_layers: int = 10000
) -> None:
    """Quantize a per-layer KV cache and serialize it as a binary blob.

    Parameters
    ----------
    kvcache
        Sequence of per-layer ``(key, value)`` array pairs. Each entry is
        assumed to have a leading dim of 1 so per-layer arrays concatenate
        into a 4-D ``(n_layer, n_head, n_tok, n_kv_dim)`` cache —
        TODO confirm against the caller's prefix layout.
    encodings
        Mapping of activation encodings containing ``past_key_{i}_in`` /
        ``past_value_{i}_in`` entries for each layer ``i``.
    filename
        Destination path for the binary cache file.
    num_layers
        Upper bound on layers to serialize. Layers without encodings are
        skipped, so the permissive default works when ``encodings`` covers
        fewer layers (previously this raised ``KeyError``).

    Raises
    ------
    ValueError
        If no layer encodings are found, or the quantized dtype has no
        file-format ID.

    File layout: a 16-byte header ``struct.pack("IIBxHHH", num_tensors,
    0xC0DE, dtype_id, n_head, n_kv_dim, n_tok)`` followed by the raw key
    cache bytes, then the raw value cache bytes.
    """
    # Gather (key, value) encodings only for layers actually present.
    # Fix: the original comprehension indexed all of range(num_layers) and
    # raised KeyError whenever num_layers exceeded the available encodings
    # (e.g. with the default of 10000).
    key_value_encodings = []
    for layer_n in range(num_layers):
        key_name = f"past_key_{layer_n}_in"
        value_name = f"past_value_{layer_n}_in"
        if key_name not in encodings or value_name not in encodings:
            break
        key_value_encodings.append([encodings[key_name], encodings[value_name]])
    if not key_value_encodings:
        raise ValueError("No past_key_*_in/past_value_*_in encodings found")

    key_q = [
        _quantize_kv_cache(cache[0], encoding[0])
        for cache, encoding in zip(kvcache, key_value_encodings, strict=False)
    ]
    value_q = [
        _quantize_kv_cache(cache[1], encoding[1])
        for cache, encoding in zip(kvcache, key_value_encodings, strict=False)
    ]

    # Stack per-layer arrays into single (n_layer, ...) caches.
    key_cache = np.concatenate(key_q)
    value_cache = np.concatenate(value_q)

    # Header: u32 tensor count, u32 magic, u8 dtype id, pad byte,
    # then three u16 dims.
    CACHE_FILE_SPEC = "IIBxHHH"
    CACHE_FILE_SPEC_SIZE = struct.calcsize(CACHE_FILE_SPEC)
    assert CACHE_FILE_SPEC_SIZE == 16
    # Position in this list is the on-disk dtype ID; the None slot keeps
    # later IDs aligned with the file format.
    DATATYPES = [
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        None,
        np.float16,
        np.float32,
        np.float64,
        bool,
    ]

    _DTYPE_TO_ID = {np.dtype(t): i for i, t in enumerate(DATATYPES) if t is not None}
    with open(filename, "wb") as handle:
        dtype_id = _DTYPE_TO_ID.get(key_cache.dtype)
        if dtype_id is None:
            raise ValueError(
                f"Unsupported cache dtype: {key_cache.dtype}. "
                f"Supported: {list(_DTYPE_TO_ID.keys())}"
            )
        n_layer, n_head, n_tok, n_kv_dim = value_cache.shape
        # One key tensor + one value tensor per layer.
        num_tensors = n_layer * 2
        handle.write(
            struct.pack(
                CACHE_FILE_SPEC, num_tensors, 0xC0DE, dtype_id, n_head, n_kv_dim, n_tok
            )
        )
        key_cache.tofile(handle)
        value_cache.tofile(handle)
| 117 | + |
| 118 | + |
class LLM_SSD_Base(LLMBase):
    """Extends LLMBase with SSD (Self Speculative Decoding) forecast support."""

    def __init__(
        self,
        *args: Any,
        ssd_forecast_ckpt: str | os.PathLike | Path | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Parameters
        ----------
        *args
            Positional arguments forwarded to LLMBase.
        ssd_forecast_ckpt
            Path to SSD forecast file. If provided, the SSD forecast token
            embeddings are concatenated to the model's embedding table.
        **kwargs
            Keyword arguments forwarded to LLMBase.
        """
        super().__init__(*args, **kwargs)
        if ssd_forecast_ckpt is None:
            return
        checkpoint = torch.load(
            ssd_forecast_ckpt, map_location="cpu", weights_only=True
        )
        forecast_embeddings = checkpoint["forecast_embedding"]
        if len(forecast_embeddings) < 1:
            return
        embedding = cast(torch.nn.Embedding, self.model.model.embed_tokens)  # type: ignore[union-attr, unused-ignore]
        assert (
            embedding.weight.shape[1] == forecast_embeddings.shape[1]
        ), "Mismatching token embedding size for embed_tokens"
        # Append the forecast-token rows to the existing embedding table and
        # keep the module's bookkeeping (num_embeddings) in sync.
        extra_rows = forecast_embeddings.to(embedding.weight.dtype)
        embedding.weight.data = torch.cat([embedding.weight.data, extra_rows], dim=0)
        embedding.num_embeddings = embedding.weight.shape[0]
| 158 | + |
| 159 | + |
class LLM_SSD_AIMETOnnx(LLM_AIMETOnnx):
    """Extends LLM_AIMETOnnx with SSD (Self Speculative Decoding) support."""

    @classmethod
    def prepare_genie_assets(
        cls,
        hub_device: hub.Device,
        checkpoint: str | os.PathLike | Path,
        llm_config: PretrainedConfig,
        context_length: int,
        model_list: list[str],
        output_path: Path,
        precision: Precision,
        encodings_path: str | os.PathLike | Path,
        input_specs: dict[str, Any],
        output_specs: dict[str, Any],
    ) -> None:
        """Prepare Genie deployment assets, then layer SSD artifacts on top.

        Delegates to ``LLM_AIMETOnnx.prepare_genie_assets`` first. If the FP
        model class exposes an SSD forecast checkpoint (``_ssd_forecast_ckpt``
        hook), this additionally:

        1. quantizes the forecast KV-cache prefix using the activation
           encodings at ``encodings_path`` and writes it to
           ``output_path / "forecast-prefix" / "kv-cache.primary.qnn-htp"``;
        2. rewrites ``output_path / genie_config.json`` to switch the dialog
           type to ``ssd-q1`` with the forecast parameters.

        Parameters
        ----------
        hub_device
            Target device; forwarded to the base implementation.
        checkpoint
            Model checkpoint; forwarded to the base implementation.
        llm_config
            Pretrained model config; forwarded to the base implementation.
        context_length
            Maximum context length; forwarded to the base implementation.
        model_list
            Compiled model part names; forwarded to the base implementation.
        output_path
            Directory receiving Genie assets; must contain
            ``genie_config.json`` after the base call.
        precision
            Quantization precision; forwarded to the base implementation.
        encodings_path
            Path to the JSON file holding ``activation_encodings``.
        input_specs, output_specs
            Model I/O specs; forwarded to the base implementation.
        """
        super().prepare_genie_assets(
            hub_device,
            checkpoint,
            llm_config,
            context_length,
            model_list,
            output_path,
            precision,
            encodings_path,
            input_specs,
            output_specs,
        )
        # SSD is opt-in: bail out unless the FP model class provides a
        # forecast-checkpoint hook that returns a usable path.
        if cls.FPModel is None or not hasattr(cls.FPModel, "_ssd_forecast_ckpt"):
            return
        ssd_forecast_ckpt = cls.FPModel._ssd_forecast_ckpt()
        if ssd_forecast_ckpt is None:
            return

        # Load SSD params once
        ssd_param = torch.load(ssd_forecast_ckpt, map_location="cpu", weights_only=True)
        ssd_prefix = ssd_param["forecast_prefix"].to(torch.float32)
        # 6-D prefix tensor: dim 0 is layer count, dim 4 is prefix length.
        # NOTE(review): the remaining dims presumably are (key/value, batch,
        # heads, head-dim) — confirm against the checkpoint producer.
        n_layer, _, _, _, len_prefix, _ = ssd_prefix.shape
        # Per-layer (key, value) pairs; the key's last two dims are swapped,
        # presumably to match the exported graph's transposed key layout —
        # TODO confirm.
        ssd_prefix_tuple = tuple(
            (ssd_prefix[i][0].permute(0, 1, 3, 2), ssd_prefix[i][1])
            for i in range(n_layer)
        )
        num_ssd_forecast_tokens = len(ssd_param["forecast_embedding"])

        # Load activation_encodings (to scan for all 'past_key_*_in' layers)
        with open(encodings_path) as f:
            encodings = json.load(f)
        if isinstance(encodings["activation_encodings"], list):
            # Convert encodings to dictionary keyed by tensor name so layer
            # entries can be looked up directly.
            encodings["activation_encodings"] = {
                v["name"]: v for v in encodings["activation_encodings"]
            }
        actv_encodings = encodings["activation_encodings"]
        # Count layers by their 'past_value_{i}_in' activation entries.
        num_layers = sum(
            1
            for ae_key in actv_encodings
            if ae_key.startswith("past_value_") and ae_key.endswith("_in")
        )

        # Create 'forecast-prefix' folder and save kvcache prefix
        # (recreate from scratch so stale files never leak into the assets).
        ssd_prefix_des_dir = output_path / "forecast-prefix"
        shutil.rmtree(ssd_prefix_des_dir, ignore_errors=True)
        ssd_prefix_des_dir.mkdir(parents=True, exist_ok=True)
        _save_kv_cache(
            ssd_prefix_tuple,
            actv_encodings,
            str(ssd_prefix_des_dir / "kv-cache.primary.qnn-htp"),
            num_layers,
        )

        # Update genie config with SSD params: switch the dialog to the
        # ssd-q1 type and record the forecast prefix metadata.
        with open(output_path / GENIE_CONFIG_JSON) as f:
            genie_config = json.load(f)
        genie_config["dialog"]["type"] = "ssd-q1"
        genie_config["dialog"]["ssd-q1"] = {
            "version": 1,
            "ssd-version": 1,
            "forecast-token-count": num_ssd_forecast_tokens,
            "forecast-prefix": len_prefix,
            "forecast-prefix-name": ssd_prefix_des_dir.name,
            "branches": [3, 2],
            "n-streams": 1,
            "p-threshold": 0.0,
        }
        with open(output_path / GENIE_CONFIG_JSON, "w") as f:
            json.dump(genie_config, f, indent=4)