Skip to content

Commit d77f045

Browse files
authored
Arm backend: Fix index bug in InsertRescaleInt32Pass (pytorch#15446)
qparams = node.meta["input_qparams"] maps arg index to qparam, which is not guaranteed to go 0..N. In the pass, range(len(qparams)) was used, which will always create indices 0..N. This caused a KeyError in several torch audio models. Signed-off-by: Erik Lundell <[email protected]>
1 parent 36ab971 commit d77f045

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,17 @@ def _get_inputs_rescaled_qparams(
140140
min_scale = min(
141141
[qp.get_scale_per_tensor() for qp in input_qparams.values()]
142142
)
143-
qparams = {
144-
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
145-
}
143+
qparams = {i: self._int32_qargs(min_scale) for i in input_qparams.keys()}
146144
elif target in [
147145
exir_ops.edge.aten.add.Tensor,
148146
exir_ops.edge.aten.sub.Tensor,
149147
]:
150-
if input_qparams[0].dtype != input_qparams[1].dtype:
148+
keys = list(input_qparams)
149+
if len(keys) < 2:
150+
raise ValueError(f"Expected two input qparams, got: {input_qparams}.")
151+
if input_qparams[keys[0]].dtype != input_qparams[keys[1]].dtype:
151152
raise ValueError(
152-
"Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}"
153+
f"Mismatch in dtype args: {input_qparams[keys[0]].dtype} != {input_qparams[keys[1]].dtype}"
153154
)
154155

155156
# We are handling two INT8 or two INT16 numbers. For INT8, if the
@@ -167,19 +168,19 @@ def _get_inputs_rescaled_qparams(
167168
max_scale_2x = 2 * max(lhs_scale, rhs_scale)
168169

169170
# Select shift based on input dtype.
170-
shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20
171+
shift_bits = 12 if input_qparams[keys[0]].dtype == torch.int16 else 20
171172

172173
scale = max_scale_2x / (1 << shift_bits)
173-
qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))}
174+
qparams = {i: self._int32_qargs(scale) for i in input_qparams.keys()}
174175
elif target in [
175176
exir_ops.edge.aten.mul.Tensor,
176177
exir_ops.edge.aten.sum.dim_IntList,
177178
]:
178179
# The input scales do not need to be adjusted for these ops; they
179180
# can remain the same.
180181
qparams = {
181-
i: self._int32_qargs(input_qparams[i].get_scale_per_tensor())
182-
for i in range(len(input_qparams))
182+
i: self._int32_qargs(qp.get_scale_per_tensor())
183+
for i, qp in input_qparams.items()
183184
}
184185
else:
185186
raise ValueError(f"Not a valid target: {target}")

0 commit comments

Comments
 (0)