Skip to content

Commit 4b52ca2

Browse files
zonglinpeng authored and facebook-github-bot committed
fix MM nullptr from zero bias (pytorch#13523)
Summary: solve ``` *Error* Unhandled user exception: LoadProhibitedCause (0x00000000) ``` Differential Revision: D80487955
1 parent b82f8f3 commit 4b52ca2

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

backends/cadence/hifi/kernels/kernels.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,19 @@ memcpy(void* dst, const void* src, size_t num_bytes) {
2121
}
2222

2323
// Allocates `size` bytes from the kernel runtime context's temporary-memory
// allocator. Returns the allocated pointer on success, or nullptr when the
// allocation fails (the failure's error code is logged for diagnosis).
void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) {
  ET_LOG(Info, "Attempting to allocate %zu bytes of temp memory", size);
  Result<void*> temp_mem_res = ctx.allocate_temp(size);
  if (!temp_mem_res.ok()) {
    // Surface the error code so a caller's nullptr check can be traced back
    // to the underlying allocator failure.
    ET_LOG(
        Error,
        "Failed to allocate temp memory, error: 0x%x",
        static_cast<uint32_t>(temp_mem_res.error()));
    return nullptr;
  }
  void* mem = temp_mem_res.get();
  ET_LOG(Info, "Successfully allocated temp memory at %p", mem);
  return mem;
}
2738

2839
// Quantize a fp32 value to an int8_t/uint8_t value

backends/cadence/hifi/operators/op_mm.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,17 @@ Tensor& mm_out(
7979
(WORD32* __restrict__)kernels::allocate_temp_memory(
8080
ctx, (n * p) * sizeof(WORD32));
8181

82+
// Allocate zero-initialized bias for matmul function (it doesn't accept
83+
// NULL)
84+
FLOAT32* __restrict__ p_bias_zero =
85+
(FLOAT32* __restrict__)kernels::allocate_temp_memory(
86+
ctx, m * sizeof(FLOAT32));
87+
88+
// Initialize bias to zero since mm operation has no bias
89+
for (int i = 0; i < m; i++) {
90+
p_bias_zero[i] = 0.0f;
91+
}
92+
8293
WORD32 p_inp_shape[2];
8394
p_inp_shape[0] = n;
8495
p_inp_shape[1] = p;
@@ -109,19 +120,20 @@ Tensor& mm_out(
109120

110121
const FLOAT32* __restrict__ p_vec = (const FLOAT32* __restrict__)p_o;
111122

123+
// mm will always be converted to addmm and to linear, and move transpose to
124+
// graph
112125
WORD32 val = xa_nn_matmul_f32xf32_f32(
113126
p_out,
114127
p_mat1,
115128
p_vec,
116-
NULL,
129+
p_bias_zero,
117130
rows,
118131
cols1,
119132
row_stride1,
120133
vec_count,
121134
vec_offset,
122135
out_offset,
123136
out_stride);
124-
125137
return out;
126138
}
127139

0 commit comments

Comments (0)