Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
generate_mdp_partition_config,
hash_dict_params,
load_json,
to_named_specializations,
)
from QEfficient.utils.export_utils import export_wrapper

Expand Down Expand Up @@ -429,6 +430,9 @@ def _compile(

return self.qpc_path

# Pop internal-only kwargs that must never reach the compiler command line.
spec_module_name = compiler_options.pop("specialization_module_name", None)

command = (
constants.COMPILER
+ [
Expand Down Expand Up @@ -500,7 +504,7 @@ def _compile(
if specializations is not None:
specializations_json = compile_dir / "specializations.json"
specializations_data = {
"specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]
"specializations": to_named_specializations(specializations, module_name=spec_module_name)
}
create_json(str(specializations_json), specializations_data)
command.append(f"-network-specialization-config={specializations_json}")
Expand Down
35 changes: 17 additions & 18 deletions QEfficient/compile/compile_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,34 @@

from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.utils import constants
from QEfficient.utils._utils import load_json, load_yaml
from QEfficient.utils._utils import load_json, load_yaml, to_named_specializations
from QEfficient.utils.logging_utils import logger


def create_and_dump_specializations(
batch_size: int, prompt_len: int, ctx_len: int, path: str, full_batch_size: Optional[int] = None
):
# Create specialization file.
specializations = {
"specializations": [
{
"batch_size": str(batch_size),
"seq_len": str(prompt_len),
"ctx_len": str(ctx_len),
},
{"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)},
]
}
# If continuous batching is enabled by proving full_batch_size we need to add FBS to the specialization file and update the batch size of decoder part to FBS
# Build flat specialization entries first, then convert to named format.
flat_specs = [
{
"batch_size": str(batch_size),
"seq_len": str(prompt_len),
"ctx_len": str(ctx_len),
},
{"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)},
]
# If continuous batching is enabled by providing full_batch_size we need to add FBS to the specialization file and update the batch size of decoder part to FBS
if full_batch_size is not None:
specializations["specializations"][0]["full_batch_size"] = str(full_batch_size)
specializations["specializations"][1]["full_batch_size"] = str(full_batch_size)
specializations["specializations"][1]["batch_size"] = str(full_batch_size)
flat_specs[0]["full_batch_size"] = str(full_batch_size)
flat_specs[1]["full_batch_size"] = str(full_batch_size)
flat_specs[1]["batch_size"] = str(full_batch_size)

# To handle repetative input in specializations when prompt_len is 1
# To handle repetitive input in specializations when prompt_len is 1
if prompt_len == 1 and full_batch_size is None:
specializations["specializations"].pop()
flat_specs.pop()

# Dump
specializations = {"specializations": to_named_specializations(flat_specs)}
with open(path, "w") as file:
json.dump(specializations, file, indent=4)

Expand Down
4 changes: 2 additions & 2 deletions QEfficient/compile/qnn_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import shutil
from typing import Dict, List, Optional

from QEfficient.utils._utils import create_json, execute_command, load_json
from QEfficient.utils._utils import create_json, execute_command, load_json, to_named_specializations
from QEfficient.utils.constants import QnnConstants
from QEfficient.utils.generate_qnn_network_specialization_config import (
generate_data_format_config,
Expand Down Expand Up @@ -423,7 +423,7 @@ def compile(
specializations_json = qpc_base_path / "specializations.json"
with open(specializations_json, "w") as fp:
json.dump(
{"specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]},
{"specializations": to_named_specializations(specializations)},
fp,
indent=4,
)
Expand Down
8 changes: 7 additions & 1 deletion QEfficient/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import PIL.Image
from tqdm import tqdm

from QEfficient.utils._utils import load_json
from QEfficient.utils._utils import load_json, to_named_specializations
from QEfficient.utils.logging_utils import logger


Expand Down Expand Up @@ -173,6 +173,9 @@ def _prepare_and_compile(module_name: str, module_obj: Any) -> None:
else:
specializations = [specializations]

# Convert flat dicts to named {name, symbols} format using the module name.
specializations = to_named_specializations(specializations, module_name=module_name)

if module_obj.qpc_path is None:
# Compile with prepared specializations
module_obj.compile(specializations=specializations, **compile_kwargs)
Expand Down Expand Up @@ -226,6 +229,9 @@ def compile_modules_sequential(
else:
specializations = [specializations]

# Convert flat dicts to named {name, symbols} format using the module name.
specializations = to_named_specializations(specializations, module_name=module_name)

if module_obj.qpc_path is None:
# Compile with prepared specializations
module_obj.compile(specializations=specializations, **compile_kwargs)
Expand Down
10 changes: 7 additions & 3 deletions QEfficient/generation/text_generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,15 @@ def get_compilation_dims(qpc_path: str) -> Tuple[int, int, Optional[int]]:
else:
raise FileNotFoundError(f"expected specializations.json file at path, {qpc_base_path}")

compilation_batch_size = int(data["specializations"][0]["batch_size"])
compilation_ctx_len = int(data["specializations"][0]["ctx_len"])
if compilation_fbs := data["specializations"][0].get("full_batch_size", None):
# Support both the legacy flat format and the new {name, symbols} format.
first = data["specializations"][0]
spec = first.get("symbols", first)
compilation_batch_size = int(spec["batch_size"])
compilation_ctx_len = int(spec["ctx_len"])
if compilation_fbs := spec.get("full_batch_size", None):
compilation_fbs = int(compilation_fbs)
return compilation_batch_size, compilation_ctx_len, compilation_fbs
return compilation_batch_size, compilation_ctx_len, compilation_fbs


def get_input_prompts(prompt: str, prompts_txt_file_path: str) -> List[str]:
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,6 +1484,7 @@ def compile(
compile_dir=compile_dir,
compile_only=True,
specializations=specializations["vision"],
specialization_module_name="Vision",
convert_to_fp16=True,
mxfp6_matmul=constants.VISION_MXFP6_MATMUL,
mdp_ts_num_devices=num_devices,
Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
onnx_exists,
padding_check_and_fix,
qpc_exists,
to_named_specializations,
)
from QEfficient.utils.hash_utils import ( # noqa: F401
create_export_hash,
Expand Down
98 changes: 98 additions & 0 deletions QEfficient/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,3 +837,101 @@ def custom_format_warning(msg, category, *args, **kwargs):
YELLOW = "\033[93m"
RESET = "\033[0m"
return f"{YELLOW}[Warning]: {msg}{RESET}\n"


def _infer_specialization_name(spec: Dict, index: int, module_name: Optional[str] = None) -> str:
"""
Infer a human-readable name for a specialization entry.

The naming convention follows the backend team's request:

- If ``module_name`` is provided it is used directly as the graph name
(e.g. ``"text_encoder"``, ``"vae_decoder"``). For multi-specialization
modules that carry a ``model_type`` key the name becomes
``f"{module_name}_model_type_{value}"`` (e.g. Wan transformer).
- "Prefill" for entries whose ``seq_len`` value is not "1" / 1.
- "Decode" for entries whose ``seq_len`` is "1" / 1.
- "Encoder" for entries that contain ``encoder_ctx_len`` but no ``seq_len``
(e.g. Whisper encoder).
- "Embedding" for entries that contain ``sequence_length`` but no ``seq_len``
(e.g. BERT / text-embedding models).
- A generic fallback ``f"Graph_{index}"`` for anything else.

Note: "Vision" is **not** inferred from spec keys because VLM vision specs
vary too much across models (some carry ``seq_len`` + ``ctx_len``, some do
not). Callers that know they are compiling a vision encoder should pass
``module_name="Vision"`` explicitly.

Parameters
----------
spec : Dict
A single flat specialization dictionary (key → value).
index : int
Zero-based position of this entry in the specializations list, used only
for the generic fallback name.
module_name : str, optional
Explicit graph name hint. When provided it takes priority over all
heuristics. Used by diffusers pipeline modules and VLM vision/lang
compile call sites.

Returns
-------
str
The inferred graph name.
"""
# Explicit name hint — used by diffusers modules and VLM vision/lang paths.
if module_name is not None:
if "model_type" in spec:
return f"{module_name}_model_type_{spec['model_type']}"
return module_name

if "seq_len" not in spec:
if "encoder_ctx_len" in spec:
return "Encoder"
if "sequence_length" in spec:
return "Embedding"
return f"Graph_{index}"
seq_len = spec["seq_len"]
if str(seq_len) == "1":
return "Decode"
return "Prefill"


def to_named_specializations(specializations: List[Dict], module_name: Optional[str] = None) -> List[Dict]:
    """
    Convert a flat list of specialization dicts to the nested ``{name, symbols}``
    format expected by the backend compiler.

    Old format (one entry)::

        {"batch_size": "1", "seq_len": "128", "ctx_len": "4096"}

    New format (one entry)::

        {"name": "Prefill", "symbols": {"batch_size": "1", "seq_len": "128", "ctx_len": "4096"}}

    For diffusers pipeline modules pass ``module_name`` so the graph name reflects
    the module (e.g. ``"text_encoder"``, ``"vae_decoder"``, ``"transformer_model_type_1"``).

    Parameters
    ----------
    specializations : List[Dict]
        List of flat specialization dicts (values may be int or str). Entries
        already in ``{name, symbols}`` form are accepted, making the function
        idempotent.
    module_name : str, optional
        Pipeline module name forwarded to ``_infer_specialization_name``.

    Returns
    -------
    List[Dict]
        New list of ``{"name": str, "symbols": Dict[str, str]}`` dicts. Input
        dicts are never mutated or aliased in the returned list.
    """
    result = []
    for index, spec in enumerate(specializations):
        if set(spec.keys()) == {"name", "symbols"}:
            # Already in named format. Re-emit a fresh copy instead of passing
            # the caller's dict through so that (a) the documented
            # Dict[str, str] contract holds even when callers supplied int
            # symbol values, and (b) the result never aliases caller-owned
            # dicts that might be mutated later.
            result.append(
                {
                    "name": str(spec["name"]),
                    "symbols": {k: str(v) for k, v in spec["symbols"].items()},
                }
            )
            continue
        name = _infer_specialization_name(spec, index, module_name=module_name)
        symbols = {k: str(v) for k, v in spec.items()}
        result.append({"name": name, "symbols": symbols})
    return result
Loading
Loading