Commit d13c6b6

swolchok authored and Zonglin Peng committed
Support Half/BFloat16 in softmax (pytorch#7867)
Partial fix for pytorch#7748.
1 parent 5fcc092 commit d13c6b6

2 files changed: 52 additions, 42 deletions

kernels/portable/cpu/op_softmax.cpp

Lines changed: 39 additions & 38 deletions
@@ -42,47 +42,48 @@ Tensor& softmax_out(
   // Adjust for negative dim
   dim = dim < 0 ? dim + nonzero_dim(in) : dim;
 
-  ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() {
-    const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
-    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
+  ET_SWITCH_FLOATHBF16_TYPES(
+      in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() {
+        const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
+        CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
 
-    apply_over_dim(
-        [in_data, out_data](
-            const size_t size, const size_t stride, const size_t base) {
-          // calculate max in softmax dim. During softmax computation each
-          // value is subtracted by the maximum in value before calling exp
-          // to preserve numerical stability.
-          const CTYPE max_in = apply_unary_reduce_fn(
-              [](const CTYPE val_in, CTYPE val_accum) {
-                return std::max(val_in, val_accum);
-              },
-              in_data + base,
-              size,
-              stride);
+        apply_over_dim(
+            [in_data, out_data](
+                const size_t size, const size_t stride, const size_t base) {
+              // calculate max in softmax dim. During softmax computation each
+              // value is subtracted by the maximum in value before calling exp
+              // to preserve numerical stability.
+              const CTYPE max_in = apply_unary_reduce_fn(
+                  [](const CTYPE val_in, CTYPE val_accum) {
+                    return std::max(val_in, val_accum);
+                  },
+                  in_data + base,
+                  size,
+                  stride);
 
-          const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
-              [max_in](const CTYPE val_in) {
-                return std::exp(val_in - max_in);
-              },
-              [](const CTYPE mapped_in, CTYPE val_accum) {
-                return val_accum + mapped_in;
-              },
-              in_data + base,
-              size,
-              stride);
+              const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
+                  [max_in](const CTYPE val_in) {
+                    return std::exp(val_in - max_in);
+                  },
+                  [](const CTYPE mapped_in, CTYPE val_accum) {
+                    return val_accum + mapped_in;
+                  },
+                  in_data + base,
+                  size,
+                  stride);
 
-          apply_unary_map_fn(
-              [max_in, temp_sum](const CTYPE val_in) {
-                return std::exp(val_in - max_in) / temp_sum;
-              },
-              in_data + base,
-              out_data + base,
-              size,
-              stride);
-        },
-        in,
-        dim);
-  });
+              apply_unary_map_fn(
+                  [max_in, temp_sum](const CTYPE val_in) {
+                    return std::exp(val_in - max_in) / temp_sum;
+                  },
+                  in_data + base,
+                  out_data + base,
+                  size,
+                  stride);
+            },
+            in,
+            dim);
+      });
 
   return out;
 }
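For reference, the kernel body above is the standard three-pass, numerically stable softmax: find the maximum along the softmax dim, accumulate the sum of exp(x - max), then normalize. Below is a minimal standalone sketch of the same computation over a contiguous slice; the function name and template parameter are illustrative only and not part of the ExecuTorch API.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Hypothetical helper shown only to illustrate the algorithm the kernel
// implements; T stands in for CTYPE (e.g. float or double).
template <typename T>
void stable_softmax_1d(const T* in, T* out, std::size_t n) {
  // Pass 1: find the maximum so exp() never sees a large positive argument.
  T max_in = in[0];
  for (std::size_t i = 1; i < n; ++i) {
    max_in = std::max(max_in, in[i]);
  }
  // Pass 2: accumulate the sum of shifted exponentials.
  T sum = T(0);
  for (std::size_t i = 0; i < n; ++i) {
    sum = sum + static_cast<T>(std::exp(in[i] - max_in));
  }
  // Pass 3: normalize.
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = static_cast<T>(std::exp(in[i] - max_in) / sum);
  }
}

Subtracting max_in matters most for the reduced-precision types this commit enables: Half's largest finite value is 65504, so exp() overflows for arguments above roughly 11, and unshifted logits would saturate quickly.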

kernels/test/op_softmax_test.cpp

Lines changed: 13 additions & 4 deletions
@@ -61,7 +61,15 @@ class OpSoftmaxOutTest : public OperatorTest {
     });
     // clang-format on
 
-    EXPECT_TENSOR_CLOSE(out, expected);
+    if (DTYPE == ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out,
+          expected,
+          1e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out, expected);
+    }
   }
 };
 
@@ -100,9 +108,10 @@ TEST_F(OpSoftmaxOutTest, HalfSupport) {
 }
 
 TEST_F(OpSoftmaxOutTest, AllDtypesSupported) {
-  test_dtype<float, ScalarType::Float>();
-  test_dtype<double, ScalarType::Double>();
-  // TODO: Also add tests for half, complex, quantized, and other types. Easiest
+#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+  // TODO: Also add tests for complex, quantized, and other types. Easiest
   // way to do that would be to make TensorFactory support zeros() and ones()
   // for those types.
 }
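The BFloat16-only tolerance in the first hunk reflects the format's precision rather than a kernel bug: bfloat16 keeps float32's exponent range but stores only 7 mantissa bits, so values round-trip with a relative error on the order of 2^-8 (about 0.4%), which is why the test compares with an rtol of 1e-2 instead of the float default. Here is a small self-contained sketch, with a hypothetical helper that is not part of the test suite, that makes the error visible by truncating a float to a bfloat16 bit pattern.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate a float32 to the bfloat16 bit pattern (round-toward-zero for
// simplicity; real conversions typically round to nearest even).
float truncate_to_bfloat16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;  // keep sign, exponent, and top 7 mantissa bits
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  const float v = 0.333333f;  // a typical softmax output value
  const float bf = truncate_to_bfloat16(v);
  // The relative error lands near 0.4%, inside the 1e-2 rtol used by the test.
  std::printf("float32=%.6f bfloat16=%.6f rel_err=%.4f\n", v, bf, (v - bf) / v);
  return 0;
}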
