Skip to content

Commit b190992

Browse files
authored
Fix Melotts csim bug by changing torch.sum op (#2894)
CSim can now run AI Hub binaries successfully after changing `y_lengths = torch.sum(w_ceil, [1, 2])` to `y_lengths = torch.sum(torch.sum(w_ceil, dim=2), dim=1)` in the encoder; this works around a QNN bug. After conversion to a context binary, QNN mishandles the following lines and computes `y_lengths` incorrectly (where `w_ceil.shape = [1, 1, 512]`): `y_lengths = torch.sum(w_ceil, [1, 2])`, `y_lengths = torch.tensor([w_ceil.detach().numpy().sum()], dtype=torch.float32)`, and `y_lengths = torch.tensor([w_ceil.squeeze().cumsum(dim=0)[-1]], dtype=torch.float32)`.
1 parent 0f28d4f commit b190992

File tree

7 files changed

+106
-54
lines changed

7 files changed

+106
-54
lines changed

qai_hub_models/models/_shared/hf_whisper/model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
import base64
9+
import os
810
from abc import abstractmethod
911
from typing import Any, cast
1012

@@ -19,10 +21,12 @@
1921
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
2022
from typing_extensions import Self
2123

24+
from qai_hub_models.configs.metadata_yaml import ModelMetadata
2225
from qai_hub_models.models._shared.hf_whisper.model_adaptation import (
2326
monkey_patch_model,
2427
)
2528
from qai_hub_models.models.common import Precision, TargetRuntime
29+
from qai_hub_models.utils.asset_loaders import CachedWebModelAsset
2630
from qai_hub_models.utils.base_model import (
2731
BaseModel,
2832
CollectionModel,
@@ -52,6 +56,12 @@
5256

5357
# Mask neg
5458
MASK_NEG = -100.0
59+
TIKTOKEN_URL = CachedWebModelAsset(
60+
"https://raw.githubusercontent.com/openai/whisper/839639a223b92ad61851baae9ad8a695ccb41ce5/whisper/assets/multilingual.tiktoken",
61+
"hf_whisper_shared",
62+
1,
63+
"multilingual.tiktoken",
64+
)
5565

5666

5767
class HfWhisperEncoder(BaseModel):
@@ -363,6 +373,26 @@ def from_pretrained(cls) -> Self:
363373
decoder = HfWhisperDecoder(config, whisper.get_decoder())
364374
return cls(encoder, decoder, config, cls.get_hf_whisper_version())
365375

376+
def write_supplementary_files(
377+
self, output_dir: str | os.PathLike, metadata: ModelMetadata
378+
) -> None:
379+
whisper_tiktoken = TIKTOKEN_URL.fetch()
380+
381+
with open(whisper_tiktoken, "rb") as f:
382+
lines = f.readlines()
383+
384+
with open(os.path.join(output_dir, "vocab.bin"), "wb") as f:
385+
for line in lines:
386+
l = line.split()
387+
if len(l) < 2:
388+
continue
389+
token = base64.b64decode(line.split()[0])
390+
if b"\0" in token:
391+
f.write(token)
392+
else:
393+
f.write(token)
394+
f.write(b"\0")
395+
366396

367397
def get_feature_extractor(
368398
hf_whisper_version: str = "openai/whisper-small",

qai_hub_models/models/_shared/melotts/app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,11 @@ def tts_to_file(
231231
length_scale_pt,
232232
noise_scale_w_pt,
233233
)
234-
235234
# Flow input
236235
y_mask = torch.unsqueeze(
237-
torch.arange(MAX_SEQ_LEN * 3) < y_lengths[:, None], dim=1
236+
torch.arange(MAX_SEQ_LEN * 3) < y_lengths.unsqueeze(dim=-1), dim=1
238237
).to(torch.float32)
238+
239239
attn_mask = x_mask.unsqueeze(dim=2) * y_mask.unsqueeze(dim=-1)
240240
attn = generate_path(w_ceil, attn_mask)
241241
attn_squeezed = attn.squeeze(1).to(torch.float32)
@@ -352,7 +352,7 @@ def get_calibration_data(
352352
)
353353

354354
y_mask = torch.unsqueeze(
355-
torch.arange(MAX_SEQ_LEN * 3) < y_lengths[:, None], dim=1
355+
torch.arange(MAX_SEQ_LEN * 3) < y_lengths.unsqueeze(dim=-1), dim=1
356356
).to(torch.float32)
357357
attn_mask = x_mask.unsqueeze(dim=2) * y_mask.unsqueeze(dim=-1)
358358
attn = generate_path(w_ceil, attn_mask)
@@ -398,7 +398,7 @@ def get_calibration_data(
398398
)
399399

400400
y_mask = torch.unsqueeze(
401-
torch.arange(MAX_SEQ_LEN * 3) < y_lengths[:, None], dim=1
401+
torch.arange(MAX_SEQ_LEN * 3) < y_lengths.unsqueeze(dim=-1), dim=1
402402
).to(torch.float32)
403403
attn_mask = x_mask.unsqueeze(dim=2) * y_mask.unsqueeze(dim=-1)
404404
attn = generate_path(w_ceil, attn_mask)

qai_hub_models/models/_shared/melotts/meloTTS_encoder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,9 @@ def forward(
106106

107107
x = self.encoder(x * x_mask, x_mask, g=g)
108108

109-
stats = self.proj(x) * x_mask
110-
m, logs = torch.split(stats, self.out_channels, dim=1)
109+
stats = self.proj(x)
110+
m, logs = torch.chunk(stats, 2, dim=1)
111+
# m, logs = torch.split(stats, self.out_channels, dim=1) # this line has the same effect as above line
111112
return x, m, logs, x_mask
112113

113114

qai_hub_models/models/_shared/melotts/meloTTS_metadata_json.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class VoiceSpec(BaseQAIHMConfig):
3838
sample_rate: int = 44100
3939
language_code: int = 0
4040
description: str
41+
capabilities: TTSCapabilities
4142

4243

4344
class TTSCapabilities(BaseQAIHMConfig):
@@ -53,15 +54,22 @@ class TTSCapabilities(BaseQAIHMConfig):
5354
supports_resampling: bool = False
5455

5556

57+
class QNNVersion(BaseQAIHMConfig):
58+
"""Version of QNN SDK."""
59+
60+
major: int
61+
minor: int
62+
patch: int = 0
63+
64+
5665
class RuntimeInfo(BaseQAIHMConfig):
5766
"""Runtime configuration information."""
5867

5968
language: str
60-
qnn_version_major: int
61-
qnn_version_minor: int
62-
qnn_version_patch: int
69+
qnn_version: QNNVersion
6370
arch_bit: int = 64
6471
scratch_mem_size_req: int = 3200000
72+
is_model_quantized: bool = False
6573

6674

6775
class ModelAssets(BaseQAIHMConfig):
@@ -92,7 +100,6 @@ class TTSMetadata(BaseQAIHMConfig):
92100
version: str = "1.0.0"
93101
description: str
94102
voices: list[VoiceSpec]
95-
capabilities: TTSCapabilities
96103
model_type: str = "melo"
97104
runtime: RuntimeInfo | None = None
98105
assets: ModelAssets | None = None
@@ -107,11 +114,11 @@ def from_melo_tts_model(
107114
model_name: str,
108115
display_name: str,
109116
description: str,
117+
tool_versions: ToolVersions,
110118
voice_specs: list[VoiceSpec] | None = None,
111119
capabilities: TTSCapabilities | None = None,
112120
runtime: RuntimeInfo | None = None,
113121
assets: ModelAssets | None = None,
114-
tool_versions: ToolVersions | None = None,
115122
) -> TTSMetadata:
116123
"""
117124
Construct a ``TTSMetadata`` object from the information
@@ -127,6 +134,8 @@ def from_melo_tts_model(
127134
Human-readable name.
128135
description
129136
Short description of the model.
137+
tool_versions
138+
Optional tool-version information.
130139
voice_specs
131140
List of :class:`VoiceSpec` describing each voice.
132141
capabilities
@@ -135,8 +144,6 @@ def from_melo_tts_model(
135144
Optional runtime information; if omitted a minimal default is used.
136145
assets
137146
Optional asset paths; if omitted a minimal default is used.
138-
tool_versions
139-
Optional tool-version information.
140147
141148
Returns
142149
-------
@@ -146,23 +153,18 @@ def from_melo_tts_model(
146153
if capabilities is None:
147154
capabilities = TTSCapabilities()
148155
if runtime is None:
149-
# Default runtime - QNN version is taken from ``tool_versions`` if present.
150-
qnn_version = {"major": 2, "minor": 33, "patch": 0}
151-
if tool_versions and tool_versions.qairt is not None:
152-
qnn_version = {
153-
"major": int(tool_versions.qairt.framework.major),
154-
"minor": int(tool_versions.qairt.framework.minor),
155-
"patch": int(
156+
assert tool_versions.qairt is not None
157+
runtime = RuntimeInfo(
158+
language=LANGUAGE_MAP[language],
159+
qnn_version=QNNVersion(
160+
major=int(tool_versions.qairt.framework.major),
161+
minor=int(tool_versions.qairt.framework.minor),
162+
patch=int(
156163
tool_versions.qairt.framework.patch
157164
if tool_versions.qairt.framework.patch
158165
else 0
159166
),
160-
}
161-
runtime = RuntimeInfo(
162-
language=LANGUAGE_MAP[language],
163-
qnn_version_major=qnn_version["major"],
164-
qnn_version_minor=qnn_version["minor"],
165-
qnn_version_patch=qnn_version["patch"],
167+
),
166168
)
167169
if assets is None:
168170
assets = ModelAssets()
@@ -173,6 +175,7 @@ def from_melo_tts_model(
173175
language=LANGUAGE_MAP[language],
174176
language_name=language.capitalize(),
175177
description=f"Default voice for {language.capitalize()}",
178+
capabilities=capabilities,
176179
)
177180
]
178181

@@ -181,7 +184,6 @@ def from_melo_tts_model(
181184
display_name=display_name,
182185
description=description,
183186
voices=voice_specs,
184-
capabilities=capabilities,
185187
runtime=runtime,
186188
assets=assets,
187189
)
@@ -249,6 +251,6 @@ def create_tts_metadata(
249251
model_name=model_name,
250252
display_name=display_name,
251253
description=description,
252-
assets=assets,
253254
tool_versions=metadata.tool_versions,
255+
assets=assets,
254256
)

qai_hub_models/models/_shared/melotts/model.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def get_input_spec() -> InputSpec:
126126
}
127127

128128
def _sample_inputs_impl(
129-
self, input_spec: InputSpec | None = None
129+
self, input_spec: InputSpec | None = None, **kwargs: Any
130130
) -> SampleInputsType:
131131
"""
132132
This is a default implementation that returns a single random data array
@@ -212,6 +212,7 @@ def forward(
212212
# This does not use a minimum of 0 because some models only have 1 speaker. That would result in a clamp(0, 0) operator, which is invalid in QNN.
213213
sid = torch.clamp(sid, max=self.model.emb_g.num_embeddings - 1)
214214
g = self.model.emb_g(sid).unsqueeze(-1)
215+
215216
x, m_p, logs_p, x_mask = self.encoder.forward(
216217
x, x_lengths, tone, language, bert, ja_bert, g=g
217218
)
@@ -223,9 +224,12 @@ def forward(
223224
logw = logw.masked_fill(x_mask == 0, -1e9)
224225

225226
w = torch.exp(logw + torch.log(self.scale * length_scale)) * x_mask
226-
w_ceil = torch.ceil(w)
227-
y_lengths = torch.sum(w_ceil, [1, 2])
228-
227+
w_ceil = torch.ceil(w) # shape: [1, 1, 512]
228+
# y_lengths = torch.sum(w_ceil, [1, 2]) # after converting to context binary, QNN can't sum correctly
229+
# y_lengths = torch.tensor([w_ceil.detach().numpy().sum() ], dtype=torch.float32) # QNN can't sum correctly
230+
# y_lengths = torch.tensor([w_ceil.squeeze().cumsum(dim=0)[-1] ], dtype=torch.float32) # QNN can't sum correctly
231+
y_lengths = torch.sum(torch.sum(w_ceil, dim=2), dim=1) # QNN sums correctly
232+
# TODO https://jira-dc.qualcomm.com/jira/projects/AISW/issues/AISW-175294
229233
return y_lengths, x_mask, m_p, logs_p, g, w_ceil
230234

231235
def sdp_forward(
@@ -251,25 +255,21 @@ def sdp_forward(
251255
shape of (1, 1, MAX_SEQ_LEN)
252256
"""
253257
sdp = self.model.sdp
254-
x = torch.detach(x)
255258
assert hasattr(sdp, "pre") and callable(sdp.pre)
256-
x = sdp.pre(x)
257259
assert hasattr(sdp, "cond") and callable(sdp.cond)
260+
assert hasattr(sdp, "convs") and callable(sdp.convs)
261+
assert hasattr(sdp, "proj") and callable(sdp.proj)
262+
assert hasattr(sdp, "flows") and isinstance(sdp.flows, Iterable)
263+
x = torch.detach(x)
264+
x = sdp.pre(x)
258265
if g is not None:
259266
g = torch.detach(g)
260267
x = x + sdp.cond(g)
261-
assert hasattr(sdp, "convs") and callable(sdp.convs)
262268
x = sdp.convs(x, x_mask)
263-
assert hasattr(sdp, "proj") and callable(sdp.proj)
264269
x = sdp.proj(x) * x_mask
265270

266-
assert hasattr(sdp, "flows") and isinstance(sdp.flows, Iterable)
267271
flows = list(sdp.flows)[::-1]
268-
flows = [
269-
*flows[:-2],
270-
flows[-1],
271-
]
272-
272+
flows = [*flows[:-2], flows[-1]]
273273
z = self.sdp_noise[:, :, : x.size(2)] * noise_scale_w
274274

275275
half_channels = None
@@ -304,16 +304,13 @@ def get_hub_compile_options(
304304
device: Device | None = None,
305305
context_graph_name: str | None = None,
306306
) -> str:
307-
if target_runtime.qairt_version_changes_compilation:
308-
other_compile_options += " --quantize_io false "
309307
compile_options = super().get_hub_compile_options(
310308
target_runtime,
311309
precision,
312310
other_compile_options,
313311
device,
314312
context_graph_name="encoder",
315313
)
316-
# # Must use --truncate_64bit_io when input tensors have type int64.
317314
if target_runtime != TargetRuntime.ONNX:
318315
compile_options += " --truncate_64bit_tensors --truncate_64bit_io "
319316
return compile_options
@@ -427,7 +424,7 @@ def get_hub_compile_options(
427424
context_graph_name: str | None = None,
428425
) -> str:
429426
if target_runtime.qairt_version_changes_compilation:
430-
other_compile_options += " --quantize_io false "
427+
other_compile_options += " --quantize_io "
431428
return super().get_hub_compile_options(
432429
target_runtime,
433430
precision,
@@ -494,7 +491,7 @@ def get_hub_compile_options(
494491
context_graph_name: str | None = None,
495492
) -> str:
496493
if target_runtime.qairt_version_changes_compilation:
497-
other_compile_options += " --quantize_io false "
494+
other_compile_options += " --quantize_io "
498495
return super().get_hub_compile_options(
499496
target_runtime,
500497
precision,
@@ -943,8 +940,6 @@ def get_hub_compile_options(
943940
device: Device | None = None,
944941
context_graph_name: str | None = None,
945942
) -> str:
946-
if target_runtime.qairt_version_changes_compilation:
947-
other_compile_options += " --quantize_io false "
948943
compile_options = super().get_hub_compile_options(
949944
target_runtime,
950945
precision,
@@ -956,10 +951,6 @@ def get_hub_compile_options(
956951
compile_options += " --truncate_64bit_tensors --truncate_64bit_io "
957952
return compile_options
958953

959-
@staticmethod
960-
def calibration_dataset_name() -> str:
961-
return "common_voice_text"
962-
963954
@staticmethod
964955
def component_precision() -> Precision:
965956
return Precision.float

0 commit comments

Comments (0)