Skip to content

Commit 8680d1a

Browse files
quic-dhirajku and asmigosw
authored and committed
Added check to not pass Custom_IO yaml when model weight and pkv are both in bfloat16.
Added a patch in cloud_infer to map the bfloat16 dtype (type key 11) to np.float16 for AI200 inference. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com> Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
1 parent 2915e49 commit 8680d1a

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,12 +538,19 @@ def _compile(
538538
command.append(f"-network-specialization-config={specializations_json}")
539539

540540
# Write custom_io.yaml file
541+
model_in_bfloat16 = self.config.torch_dtype == torch.bfloat16
542+
pkv_in_bfloat16 = any("past_" in key and "bfloat16" in value for key, value in custom_io.items())
541543
if custom_io is not None:
542544
custom_io_yaml = compile_dir / "custom_io.yaml"
543545
with open(custom_io_yaml, "w") as fp:
544546
for io_name, dtype in custom_io.items():
545547
fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n")
546-
command.append(f"-custom-IO-list-file={custom_io_yaml}")
548+
if model_in_bfloat16 and pkv_in_bfloat16:
549+
logger.warning(
550+
"Model and Past KV types are both bfloat16. Custom IO list file will be ignored during compile."
551+
)
552+
else:
553+
command.append(f"-custom-IO-list-file={custom_io_yaml}")
547554

548555
command.append(f"-aic-binary-dir={qpc_path}")
549556
logger.info(f"Running compiler: {' '.join(command)}")

QEfficient/generation/cloud_infer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565

6666
# Build dtype mapping once (depends on aicapi constants)
6767
self.aic_to_np_dtype_mapping = {
68+
getattr(aicapi, "BFLOAT16_TYPE", 11): np.dtype(np.float16),
6869
aicapi.FLOAT_TYPE: np.dtype(np.float32),
6970
aicapi.FLOAT_16_TYPE: np.dtype(np.float16),
7071
aicapi.INT8_Q_TYPE: np.dtype(np.int8),

0 commit comments

Comments (0)