Skip to content

Commit 075154b

Browse files
author
Shreya Jain
authored
Llama 3.2 3B Instruct SSD (#2895)
Most of the work was done by @spappach_QCOM. Reopening after the mirroring.
1 parent f3d59d9 commit 075154b

File tree

30 files changed

+1460
-1
lines changed

30 files changed

+1460
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ and many more.
353353
| [Llama-v3.1-8B-Instruct](https://aihub.qualcomm.com/models/llama_v3_1_8b_instruct) | [qai_hub_models.models.llama_v3_1_8b_instruct](qai_hub_models/models/llama_v3_1_8b_instruct/README.md) |
354354
| [Llama-v3.2-1B-Instruct](https://aihub.qualcomm.com/models/llama_v3_2_1b_instruct) | [qai_hub_models.models.llama_v3_2_1b_instruct](qai_hub_models/models/llama_v3_2_1b_instruct/README.md) |
355355
| [Llama-v3.2-3B-Instruct](https://aihub.qualcomm.com/models/llama_v3_2_3b_instruct) | [qai_hub_models.models.llama_v3_2_3b_instruct](qai_hub_models/models/llama_v3_2_3b_instruct/README.md) |
356+
| [Llama-v3.2-3B-Instruct-SSD](https://aihub.qualcomm.com/models/llama_v3_2_3b_instruct_ssd) | [qai_hub_models.models.llama_v3_2_3b_instruct_ssd](qai_hub_models/models/llama_v3_2_3b_instruct_ssd/README.md) |
356357
| [Llama3-TAIDE-LX-8B-Chat-Alpha1](https://aihub.qualcomm.com/models/llama_v3_taide_8b_chat) | [qai_hub_models.models.llama_v3_taide_8b_chat](qai_hub_models/models/llama_v3_taide_8b_chat/README.md) |
357358
| [Mistral-7B-Instruct-v0.3](https://aihub.qualcomm.com/models/mistral_7b_instruct_v0_3) | [qai_hub_models.models.mistral_7b_instruct_v0_3](qai_hub_models/models/mistral_7b_instruct_v0_3/README.md) |
358359
| [Mobile-Bert-Uncased-Google](https://aihub.qualcomm.com/models/mobile_bert_uncased_google) | [qai_hub_models.models.mobile_bert_uncased_google](qai_hub_models/models/mobile_bert_uncased_google/README.md) |

qai_hub_models/models/_shared/llm/test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ def test_cli_default_device_select_component(
539539
skip_download: bool,
540540
skip_summary: bool,
541541
target_runtime: TargetRuntime,
542+
decode_sequence_length: int,
542543
) -> None:
543544
context_length = 4096
544545
sequence_length = 128
@@ -610,7 +611,7 @@ def test_cli_default_device_select_component(
610611
instantiation_name = (
611612
f"ar{sequence_length}_cl{context_length}"
612613
if i < parts
613-
else f"ar1_cl{context_length}"
614+
else f"ar{decode_sequence_length}_cl{context_length}"
614615
)
615616
assert (
616617
call.kwargs["name"]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# ---------------------------------------------------------------------
2+
# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries.
3+
# SPDX-License-Identifier: BSD-3-Clause
4+
# ---------------------------------------------------------------------
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# ---------------------------------------------------------------------
2+
# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries.
3+
# SPDX-License-Identifier: BSD-3-Clause
4+
# ---------------------------------------------------------------------
5+
6+
from __future__ import annotations
7+
8+
import contextlib
9+
import json
10+
import os
11+
import shutil
12+
import struct
13+
from pathlib import Path
14+
from typing import Any, cast
15+
16+
import numpy as np
17+
import qai_hub as hub
18+
import torch
19+
20+
from qai_hub_models.models._shared.llm.model import LLM_AIMETOnnx, LLMBase
21+
from qai_hub_models.utils.base_model import Precision
22+
23+
with contextlib.suppress(ImportError):
24+
from transformers import PretrainedConfig
25+
26+
GENIE_CONFIG_JSON = "genie_config.json"
27+
28+
29+
def _quantize_kv_cache(f: Any, encoding: Any, bw: int = 8) -> Any:
30+
def _round(x: Any) -> Any:
31+
sign = np.where(x < 0, -1, 1).astype(np.float32)
32+
return np.floor(np.abs(x) + 0.5) * sign
33+
34+
def _quantize(f: Any, scale: Any, offset: Any, dtype: np.dtype) -> Any:
35+
q = _round(f / scale - offset)
36+
return q.clip(np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
37+
38+
if isinstance(encoding, list):
39+
scale, offset = encoding[0]["scale"], encoding[0]["offset"]
40+
assert encoding[0]["bitwidth"] == bw
41+
elif isinstance(encoding, dict):
42+
scale, offset = encoding["scale"][0], encoding["offset"][0]
43+
assert encoding["bw"] == bw
44+
else:
45+
raise TypeError(f"Unknown encoding format: {type(encoding)}")
46+
47+
f = np.array(f)
48+
_BW_TO_DTYPE: dict[int, np.dtype[Any]] = {
49+
8: np.dtype(np.uint8),
50+
16: np.dtype(np.uint16),
51+
32: np.dtype(np.uint32),
52+
64: np.dtype(np.uint64),
53+
}
54+
if bw not in _BW_TO_DTYPE:
55+
raise ValueError(
56+
f"Unsupported bitwidth: {bw}. Supported: {list(_BW_TO_DTYPE.keys())}"
57+
)
58+
bw_dtype = _BW_TO_DTYPE[bw]
59+
return _quantize(f, scale, offset, bw_dtype)
60+
61+
62+
def _save_kv_cache(
    kvcache: Any, encodings: Any, filename: str, num_layers: int = 10000
) -> None:
    """Serialize a quantized KV-cache prefix to a Genie binary cache file.

    The file starts with a 16-byte packed header (tensor count, magic
    0xC0DE, dtype id, head count, per-head dim, token count) followed by
    the raw quantized key tensors, then the value tensors.
    """
    # Per-layer [key_encoding, value_encoding]; raises KeyError if the
    # encodings dict is missing a layer's entries.
    per_layer_enc = []
    for layer_idx in range(num_layers):
        per_layer_enc.append(
            [
                encodings[f"past_key_{layer_idx}_in"],
                encodings[f"past_value_{layer_idx}_in"],
            ]
        )

    quantized_keys = []
    quantized_values = []
    for cache, enc in zip(kvcache, per_layer_enc, strict=False):
        quantized_keys.append(_quantize_kv_cache(cache[0], enc[0]))
        quantized_values.append(_quantize_kv_cache(cache[1], enc[1]))

    key_cache = np.concatenate(quantized_keys)
    value_cache = np.concatenate(quantized_values)

    # Header layout: u32 tensor count, u32 magic, u8 dtype id, one pad
    # byte, then three u16 dims -- 16 bytes with native alignment.
    header_fmt = "IIBxHHH"
    assert struct.calcsize(header_fmt) == 16

    # Position in this list is the on-disk dtype id; index 8 is reserved.
    ordered_dtypes = [
        np.uint8,
        np.uint16,
        np.uint32,
        np.uint64,
        np.int8,
        np.int16,
        np.int32,
        np.int64,
        None,
        np.float16,
        np.float32,
        np.float64,
        bool,
    ]
    dtype_ids = {np.dtype(t): i for i, t in enumerate(ordered_dtypes) if t is not None}

    with open(filename, "wb") as handle:
        dtype_id = dtype_ids.get(key_cache.dtype)
        if dtype_id is None:
            raise ValueError(
                f"Unsupported cache dtype: {key_cache.dtype}. "
                f"Supported: {list(dtype_ids.keys())}"
            )
        # Header dims come from the value cache; keys are assumed to share
        # them (transposed keys are not reflected in the header).
        n_layer, n_head, n_tok, n_kv_dim = value_cache.shape
        handle.write(
            struct.pack(
                header_fmt, n_layer * 2, 0xC0DE, dtype_id, n_head, n_kv_dim, n_tok
            )
        )
        key_cache.tofile(handle)
        value_cache.tofile(handle)
117+
118+
119+
class LLM_SSD_Base(LLMBase):
    """Extends LLMBase with SSD (Self Speculative Decoding) forecast support."""

    def __init__(
        self,
        *args: Any,
        ssd_forecast_ckpt: str | os.PathLike | Path | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Parameters
        ----------
        *args
            Positional arguments forwarded to LLMBase.
        ssd_forecast_ckpt
            Path to SSD forecast file. If provided, the SSD forecast token
            embeddings are concatenated to the model's embedding table.
        **kwargs
            Keyword arguments forwarded to LLMBase.
        """
        super().__init__(*args, **kwargs)
        if ssd_forecast_ckpt is None:
            return
        forecast_state = torch.load(
            ssd_forecast_ckpt, map_location="cpu", weights_only=True
        )
        forecast_embeddings = forecast_state["forecast_embedding"]
        if len(forecast_embeddings) < 1:
            return
        # Append the forecast-token rows to the existing embedding table so
        # SSD token ids resolve past the original vocabulary.
        embedding = cast(torch.nn.Embedding, self.model.model.embed_tokens)  # type: ignore[union-attr, unused-ignore]
        assert (
            embedding.weight.shape[1] == forecast_embeddings.shape[1]
        ), "Mismatching token embedding size for embed_tokens"
        embedding.weight.data = torch.cat(
            [
                embedding.weight.data,
                forecast_embeddings.to(embedding.weight.dtype),
            ],
            dim=0,
        )
        # Keep the module's bookkeeping in sync with the enlarged table.
        embedding.num_embeddings = embedding.weight.shape[0]
158+
159+
160+
class LLM_SSD_AIMETOnnx(LLM_AIMETOnnx):
    """Extends LLM_AIMETOnnx with SSD (Self Speculative Decoding) support."""

    @classmethod
    def prepare_genie_assets(
        cls,
        hub_device: hub.Device,
        checkpoint: str | os.PathLike | Path,
        llm_config: PretrainedConfig,
        context_length: int,
        model_list: list[str],
        output_path: Path,
        precision: Precision,
        encodings_path: str | os.PathLike | Path,
        input_specs: dict[str, Any],
        output_specs: dict[str, Any],
    ) -> None:
        """Prepare Genie deployment assets, then layer SSD artifacts on top.

        After the base-class asset preparation, this writes the quantized
        SSD forecast KV-cache prefix into ``output_path/forecast-prefix``
        and rewrites ``genie_config.json`` to enable the ``ssd-q1`` dialog.
        No-ops when the FP model does not supply an SSD forecast checkpoint.
        """
        super().prepare_genie_assets(
            hub_device,
            checkpoint,
            llm_config,
            context_length,
            model_list,
            output_path,
            precision,
            encodings_path,
            input_specs,
            output_specs,
        )
        # SSD is opt-in: bail out unless the FP model exposes a forecast
        # checkpoint hook and it returns a path.
        if cls.FPModel is None or not hasattr(cls.FPModel, "_ssd_forecast_ckpt"):
            return
        ssd_forecast_ckpt = cls.FPModel._ssd_forecast_ckpt()
        if ssd_forecast_ckpt is None:
            return

        # Load SSD params once
        ssd_param = torch.load(ssd_forecast_ckpt, map_location="cpu", weights_only=True)
        ssd_prefix = ssd_param["forecast_prefix"].to(torch.float32)
        # 6-D unpack: dims 0 and 4 are layer count and prefix length; the
        # exact meaning of the other dims is checkpoint-defined -- TODO confirm.
        n_layer, _, _, _, len_prefix, _ = ssd_prefix.shape
        # Per layer: (key, value); keys get their last two dims swapped,
        # presumably to match the exported transposed-key layout -- verify.
        ssd_prefix_tuple = tuple(
            (ssd_prefix[i][0].permute(0, 1, 3, 2), ssd_prefix[i][1])
            for i in range(n_layer)
        )
        num_ssd_forecast_tokens = len(ssd_param["forecast_embedding"])

        # Load activation_encodings (to scan for all 'past_key_*_in' layers)
        with open(encodings_path) as f:
            encodings = json.load(f)
        if isinstance(encodings["activation_encodings"], list):
            # Convert encodings to dictionary
            encodings["activation_encodings"] = {
                v["name"]: v for v in encodings["activation_encodings"]
            }
        actv_encodings = encodings["activation_encodings"]
        # Infer layer count from the encoding names rather than the config.
        num_layers = sum(
            1
            for ae_key in actv_encodings
            if ae_key.startswith("past_value_") and ae_key.endswith("_in")
        )

        # Create 'forecast-prefix' folder and save kvcache prefix
        ssd_prefix_des_dir = output_path / "forecast-prefix"
        # Recreate the folder from scratch so stale cache files never survive.
        shutil.rmtree(ssd_prefix_des_dir, ignore_errors=True)
        ssd_prefix_des_dir.mkdir(parents=True, exist_ok=True)
        _save_kv_cache(
            ssd_prefix_tuple,
            actv_encodings,
            str(ssd_prefix_des_dir / "kv-cache.primary.qnn-htp"),
            num_layers,
        )

        # Update genie config with SSD params
        with open(output_path / GENIE_CONFIG_JSON) as f:
            genie_config = json.load(f)
        genie_config["dialog"]["type"] = "ssd-q1"
        genie_config["dialog"]["ssd-q1"] = {
            "version": 1,
            "ssd-version": 1,
            "forecast-token-count": num_ssd_forecast_tokens,
            "forecast-prefix": len_prefix,
            "forecast-prefix-name": ssd_prefix_des_dir.name,
            # Fixed speculative-tree shape and decode knobs; the semantics
            # of these values are defined by the Genie runtime -- confirm
            # against its ssd-q1 schema before changing.
            "branches": [3, 2],
            "n-streams": 1,
            "p-threshold": 0.0,
        }
        with open(output_path / GENIE_CONFIG_JSON, "w") as f:
            json.dump(genie_config, f, indent=4)

qai_hub_models/models/llama_v3_1_8b_instruct/test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def test_cli_default_device_select_component(
234234
skip_download,
235235
skip_summary,
236236
target_runtime,
237+
decode_sequence_length=1,
237238
)
238239

239240

qai_hub_models/models/llama_v3_1_sea_lion_3_5_8b_r/test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def test_cli_default_device_select_component(
226226
skip_download,
227227
skip_summary,
228228
target_runtime,
229+
decode_sequence_length=1,
229230
)
230231

231232

qai_hub_models/models/llama_v3_2_1b_instruct/test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def test_cli_default_device_select_component(
244244
skip_download,
245245
skip_summary,
246246
target_runtime,
247+
decode_sequence_length=1,
247248
)
248249

249250

qai_hub_models/models/llama_v3_2_3b_instruct/test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def test_cli_default_device_select_component(
223223
skip_download,
224224
skip_summary,
225225
target_runtime,
226+
decode_sequence_length=1,
226227
)
227228

228229

0 commit comments

Comments
 (0)