Skip to content

Commit 4f1ea07

Browse files
kuharfelipepiovezan
authored andcommitted
[ADT] Add sum_of and product_of accumulate wrappers (llvm#162129)
Also extend the `accumulate` wrapper to accept a binary operator. The goal is to the most common usage of `std::accumulate` across the codebase -- calculating either the sum of or the product of all values. (cherry picked from commit 454ef02)
1 parent 48a5800 commit 4f1ea07

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,6 +1738,28 @@ template <typename R, typename E> auto accumulate(R &&Range, E &&Init) {
17381738
std::forward<E>(Init));
17391739
}
17401740

1741+
/// Wrapper for std::accumulate with a binary operator.
1742+
template <typename R, typename E, typename BinaryOp>
1743+
auto accumulate(R &&Range, E &&Init, BinaryOp &&Op) {
1744+
return std::accumulate(adl_begin(Range), adl_end(Range),
1745+
std::forward<E>(Init), std::forward<BinaryOp>(Op));
1746+
}
1747+
1748+
/// Returns the sum of all values in `Range` with `Init` initial value.
1749+
/// The default initial value is 0.
1750+
template <typename R, typename E = detail::ValueOfRange<R>>
1751+
auto sum_of(R &&Range, E Init = E{0}) {
1752+
return accumulate(std::forward<R>(Range), std::move(Init));
1753+
}
1754+
1755+
/// Returns the product of all values in `Range` with `Init` initial value.
1756+
/// The default initial value is 1.
1757+
template <typename R, typename E = detail::ValueOfRange<R>>
1758+
auto product_of(R &&Range, E Init = E{1}) {
1759+
return accumulate(std::forward<R>(Range), std::move(Init),
1760+
std::multiplies<>{});
1761+
}
1762+
17411763
/// Provide wrappers to std::for_each which take ranges instead of having to
17421764
/// pass begin/end explicitly.
17431765
template <typename R, typename UnaryFunction>

llvm/unittests/ADT/STLExtrasTest.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <array>
1515
#include <climits>
1616
#include <cstddef>
17+
#include <functional>
1718
#include <initializer_list>
1819
#include <iterator>
1920
#include <list>
@@ -1610,6 +1611,54 @@ TEST(STLExtrasTest, Accumulate) {
16101611
EXPECT_EQ(accumulate(V1, 10), std::accumulate(V1.begin(), V1.end(), 10));
16111612
EXPECT_EQ(accumulate(drop_begin(V1), 7),
16121613
std::accumulate(V1.begin() + 1, V1.end(), 7));
1614+
1615+
EXPECT_EQ(accumulate(V1, 2, std::multiplies<>{}), 240);
1616+
}
1617+
1618+
TEST(STLExtrasTest, SumOf) {
1619+
EXPECT_EQ(sum_of(std::vector<int>()), 0);
1620+
EXPECT_EQ(sum_of(std::vector<int>(), 1), 1);
1621+
std::vector<int> V1 = {1, 2, 3, 4, 5};
1622+
static_assert(std::is_same_v<decltype(sum_of(V1)), int>);
1623+
static_assert(std::is_same_v<decltype(sum_of(V1, 1)), int>);
1624+
EXPECT_EQ(sum_of(V1), 15);
1625+
EXPECT_EQ(sum_of(V1, 1), 16);
1626+
1627+
std::vector<float> V2 = {1.0f, 2.0f, 4.0f};
1628+
static_assert(std::is_same_v<decltype(sum_of(V2)), float>);
1629+
static_assert(std::is_same_v<decltype(sum_of(V2), 1.0f), float>);
1630+
static_assert(std::is_same_v<decltype(sum_of(V2), 1.0), double>);
1631+
EXPECT_EQ(sum_of(V2), 7.0f);
1632+
EXPECT_EQ(sum_of(V2, 1.0f), 8.0f);
1633+
1634+
// Make sure that for a const argument the return value is non-const.
1635+
const std::vector<float> V3 = {1.0f, 2.0f};
1636+
static_assert(std::is_same_v<decltype(sum_of(V3)), float>);
1637+
EXPECT_EQ(sum_of(V3), 3.0f);
1638+
}
1639+
1640+
TEST(STLExtrasTest, ProductOf) {
1641+
EXPECT_EQ(product_of(std::vector<int>()), 1);
1642+
EXPECT_EQ(product_of(std::vector<int>(), 0), 0);
1643+
EXPECT_EQ(product_of(std::vector<int>(), 1), 1);
1644+
std::vector<int> V1 = {1, 2, 3, 4, 5};
1645+
static_assert(std::is_same_v<decltype(product_of(V1)), int>);
1646+
static_assert(std::is_same_v<decltype(product_of(V1, 1)), int>);
1647+
EXPECT_EQ(product_of(V1), 120);
1648+
EXPECT_EQ(product_of(V1, 1), 120);
1649+
EXPECT_EQ(product_of(V1, 2), 240);
1650+
1651+
std::vector<float> V2 = {1.0f, 2.0f, 4.0f};
1652+
static_assert(std::is_same_v<decltype(product_of(V2)), float>);
1653+
static_assert(std::is_same_v<decltype(product_of(V2), 1.0f), float>);
1654+
static_assert(std::is_same_v<decltype(product_of(V2), 1.0), double>);
1655+
EXPECT_EQ(product_of(V2), 8.0f);
1656+
EXPECT_EQ(product_of(V2, 4.0f), 32.0f);
1657+
1658+
// Make sure that for a const argument the return value is non-const.
1659+
const std::vector<float> V3 = {1.0f, 2.0f};
1660+
static_assert(std::is_same_v<decltype(product_of(V3)), float>);
1661+
EXPECT_EQ(product_of(V3), 2.0f);
16131662
}
16141663

16151664
struct Foo;

0 commit comments

Comments
 (0)