Skip to content

Commit cd8cfad

Browse files
committed
Added custom_dtype support for wav2vec2
Signed-off-by: asmigosw <asmigosw@qti.qualcomm.com>
Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
1 parent a81f486 commit cd8cfad

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4059,7 +4059,7 @@ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
4059 4059           seq_len = constants.WAV2VEC2_MAX_SEQ_LEN
4060 4060
4061 4061           example_inputs = {
4062      -             "input_values": torch.zeros((bs, seq_len), dtype=torch.float32),
     4062 +             "input_values": torch.zeros((bs, seq_len), dtype=self.model.config.torch_dtype),
4063 4063           }
4064 4064
4065 4065           dynamic_axes = {"input_values": {0: "batch_size", 1: "seq_len"}}
@@ -4205,6 +4205,8 @@ def cloud_ai_100_feature_generate(
4205 4205           input_values = np.array(
4206 4206               torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
4207 4207           )
     4208 +         needed_dtype = self.model.config.torch_dtype
     4209 +         input_values = input_values.astype(DTYPE_TO_STRING_MAP[needed_dtype])
4208 4210           inputs = dict(input_values=input_values)
4209 4211           outputs = self.qpc_session.run(inputs)
4210 4212
0 commit comments

Comments
 (0)