
Commit d6d53af

zonglinpeng authored and facebook-github-bot committed
link mean dim kernels (pytorch#8053)
Summary: titled

Differential Revision: D68845587
1 parent e8ee36c commit d6d53af

File tree

1 file changed: +26 -24 lines changed

backends/cadence/fusion_g3/operators/op_mean.cpp

Lines changed: 26 additions & 24 deletions
@@ -28,6 +28,29 @@ namespace impl {
 namespace G3 {
 namespace native {
 
+template <typename CTYPE_IN, typename CTYPE_OUT>
+void mean_out_(
+    const Tensor& in,
+    optional<ArrayRef<int64_t>> dim_list,
+    __ET_UNUSED bool keepdim,
+    __ET_UNUSED optional<ScalarType> dtype,
+    Tensor& out) {
+  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+  const size_t num = torch::executor::get_reduced_dim_product(in, dim_list);
+  for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+    CTYPE_OUT sum = 0;
+    if (in.numel() > 0) {
+      sum = torch::executor::map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+          [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
+          [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+          in,
+          dim_list,
+          out_ix);
+    }
+    out_data[out_ix] = sum / static_cast<float>(num);
+  }
+}
+
 int prepare_data(
     const Tensor& in,
     Tensor& out,
@@ -60,7 +83,7 @@ int prepare_data(
   return num_axis_dims;
 }
 
-Tensor& mean_dim_out(
+Tensor& mean_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     optional<ArrayRef<int64_t>> dim_list,
@@ -169,29 +192,8 @@ Tensor& mean_dim_out(
         InvalidArgument,
         out);
 
-    ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
-      ET_SWITCH_FLOATH_TYPES(
-          out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
-            CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-            const size_t num =
-                torch::executor::get_reduced_dim_product(in, dim_list);
-            for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-              CTYPE_OUT sum = 0;
-              if (in.numel() > 0) {
-                sum = torch::executor::
-                    map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
-                        [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                        [](CTYPE_OUT outv, CTYPE_OUT acc) {
-                          return acc + outv;
-                        },
-                        in,
-                        dim_list,
-                        out_ix);
-              }
-              out_data[out_ix] = sum / static_cast<float>(num);
-            }
-          });
-    });
+    mean_out_<float, float>(in, dim_list, keepdim, dtype, out);
+    return out;
   }
 
   return out;
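
The change above folds the per-element mean computation into a templated helper, mean_out_, which the renamed mean_out entry point instantiates as mean_out_<float, float>. For readers who want the arithmetic without the ExecuTorch plumbing, here is a standalone sketch of the same reduction on a plain row-major buffer: num is the product of the reduced dimensions' sizes, and each output element is the sum over those dimensions divided by num. It deliberately avoids the ExecuTorch utilities (map_reduce_over_dim_list, get_reduced_dim_product); the helper name mean_over_dims, its signature, the contiguous row-major layout, and keepdim = false are assumptions made for this illustration only.

// Standalone sketch only (not ExecuTorch code): mean over a set of dims of a
// contiguous row-major tensor, mirroring the sum / num pattern in the diff.
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: averages away `dims` of a row-major tensor `in`
// with shape `shape` (keepdim = false).
std::vector<float> mean_over_dims(
    const std::vector<float>& in,
    const std::vector<int64_t>& shape,
    const std::vector<int64_t>& dims) {
  // num: product of the reduced dims' sizes, i.e. the divisor for the mean.
  int64_t num = 1;
  std::vector<bool> reduced(shape.size(), false);
  for (int64_t d : dims) {
    reduced[static_cast<size_t>(d)] = true;
    num *= shape[static_cast<size_t>(d)];
  }

  // The output keeps only the non-reduced dims.
  int64_t out_numel = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    if (!reduced[d]) {
      out_numel *= shape[d];
    }
  }
  std::vector<float> out(static_cast<size_t>(out_numel), 0.0f);

  // Accumulate each input element into the output slot obtained by dropping
  // the reduced coordinates, then divide every slot by num.
  int64_t in_numel = 1;
  for (int64_t s : shape) {
    in_numel *= s;
  }
  for (int64_t ix = 0; ix < in_numel; ++ix) {
    // Unflatten ix into per-dim coordinates (row-major order).
    std::vector<int64_t> coord(shape.size());
    int64_t rem = ix;
    for (size_t d = shape.size(); d-- > 0;) {
      coord[d] = rem % shape[d];
      rem /= shape[d];
    }
    // Flatten only the kept dims to get the output index.
    int64_t out_ix = 0;
    for (size_t d = 0; d < shape.size(); ++d) {
      if (!reduced[d]) {
        out_ix = out_ix * shape[d] + coord[d];
      }
    }
    out[static_cast<size_t>(out_ix)] += in[static_cast<size_t>(ix)];
  }
  for (float& v : out) {
    v /= static_cast<float>(num);
  }
  return out;
}

int main() {
  // 2x3 input, mean over dim 1: expect 2 and 5.
  const std::vector<float> in = {1, 2, 3, 4, 5, 6};
  for (float v : mean_over_dims(in, {2, 3}, {1})) {
    std::cout << v << '\n';
  }
  return 0;
}

Unlike the kernel, the sketch does not special-case empty inputs (the diff guards the reduction with in.numel() > 0) and works only in float; the templated helper in the diff, by contrast, leaves the input and output types as parameters even though this commit instantiates it solely for float.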
