diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 814958a93..50eb2b7b3 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -205,11 +205,12 @@ cmake --build --preset tidy To autofix, use: ```bash -cmake --preset --preset tidy -DCMAKE_CXX_CLANG_TIDY="clang-tidy;--fix" -cmake --build --preset tidy -j1 +cmake --preset tidy-fix +cmake --build --preset tidy-fix ``` -Remember to build single-threaded if applying fixes! +We also provide matching `--workflow` presets, but you'll need a newer CMake for that +(you can use pip to get it, though). ## Include what you use diff --git a/CMakeLists.txt b/CMakeLists.txt index e766348fa..d229821cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -120,8 +120,15 @@ option(BOOST_HISTOGRAM_ERRORS "Make warnings errors (for CI mostly)") # Adding warnings # Boost.Histogram doesn't pass sign -Wsign-conversion if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" OR "${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") - target_compile_options(_core PRIVATE -Wall -Wextra -pedantic-errors -Wconversion -Wsign-compare - -Wno-unused-value) + target_compile_options( + _core + PRIVATE -Wall + -Wextra + -pedantic-errors + -Wconversion + -Wsign-compare + -Wno-unused-value + -Wno-sign-conversion) if(BOOST_HISTOGRAM_ERRORS) target_compile_options(_core PRIVATE -Werror) endif() diff --git a/CMakePresets.json b/CMakePresets.json index 621e7f0ad..5f6cc352b 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -24,6 +24,14 @@ "cacheVariables": { "CMAKE_CXX_CLANG_TIDY": "clang-tidy;--warnings-as-errors=*" } + }, + { + "name": "tidy-fix", + "displayName": "Clang-tidy autofix", + "inherits": "tidy", + "cacheVariables": { + "CMAKE_CXX_CLANG_TIDY": "clang-tidy;--warnings-as-errors=*;--fix" + } } ], "buildPresets": [ @@ -38,6 +46,13 @@ "displayName": "Clang-tidy build", "configurePreset": "tidy", "nativeToolOptions": ["-k0"] + }, + { + "name": "tidy-fix", + "displayName": "Clang-tidy autofix build", + "configurePreset": "tidy-fix", + 
"jobs": 1, + "nativeToolOptions": ["-k0"] } ], "testPresets": [ @@ -59,6 +74,22 @@ { "type": "build", "name": "default" }, { "type": "test", "name": "default" } ] + }, + { + "name": "tidy", + "displayName": "Clang-tidy workflow", + "steps": [ + { "type": "configure", "name": "tidy" }, + { "type": "build", "name": "tidy" } + ] + }, + { + "name": "tidy-fix", + "displayName": "Clang-tidy autofix workflow", + "steps": [ + { "type": "configure", "name": "tidy-fix" }, + { "type": "build", "name": "tidy-fix" } + ] } ] } diff --git a/include/bh_python/fill.hpp b/include/bh_python/fill.hpp index 1f727f154..14256e943 100644 --- a/include/bh_python/fill.hpp +++ b/include/bh_python/fill.hpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -181,7 +182,7 @@ void fill_impl(bh::detail::accumulator_traits_holder, finalize_args(kwargs); // releasing gil here is safe, we don't manipulate refcounts - py::gil_scoped_release const lock; + const py::gil_scoped_release lock; variant::visit( overload([&h, &vargs](const variant::monostate&) { h.fill(vargs); }, [&h, &vargs](const auto& w) { h.fill(vargs, bh::weight(w)); }), @@ -203,7 +204,7 @@ void fill_impl(bh::detail::accumulator_traits_holder, throw std::invalid_argument("Sample array must be 1D"); // releasing gil here is safe, we don't manipulate refcounts - py::gil_scoped_release const lock; + const py::gil_scoped_release lock; variant::visit( overload([&h, &vargs, &sarray]( const variant::monostate&) { h.fill(vargs, bh::sample(sarray)); }, @@ -213,6 +214,34 @@ void fill_impl(bh::detail::accumulator_traits_holder, weight); } +// for multi_weight +template +void fill_impl(bh::detail::accumulator_traits_holder&>, + Histogram& h, + const VArgs& vargs, + const weight_t& weight, + py::kwargs& kwargs) { + boost::ignore_unused(weight); + auto s = required_arg(kwargs, "sample"); + finalize_args(kwargs); + auto sarray = py::cast>(s); + if(sarray.ndim() != 2) + throw std::invalid_argument("Sample array for MultiWeight 
must be 2D"); + + auto buf = sarray.request(); + // releasing gil here is safe, we don't manipulate refcounts + const py::gil_scoped_release lock; + const auto buf_shape0 = static_cast(buf.shape[0]); + const auto buf_shape1 = static_cast(buf.shape[1]); + auto* src = static_cast(buf.ptr); + std::vector> vec_s; + vec_s.reserve(buf_shape0); + for(std::size_t i = 0; i < buf_shape0; i++) { + vec_s.emplace_back(src + (i * buf_shape1), buf_shape1); + } + h.fill(vargs, bh::sample(vec_s)); +} + } // namespace detail template diff --git a/include/bh_python/histogram.hpp b/include/bh_python/histogram.hpp index 0025dd7db..23903aafc 100644 --- a/include/bh_python/histogram.hpp +++ b/include/bh_python/histogram.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -96,6 +97,15 @@ py::buffer_info make_buffer(bh::histogram>& return detail::make_buffer_impl(axes, flow, static_cast(buffer.ptr)); } +/// Specialization for multi_weight buffer +template +py::buffer_info make_buffer(bh::histogram>& h, bool flow) { + const auto& axes = bh::unsafe_access::axes(h); + auto& storage = bh::unsafe_access::storage(h); + return detail::make_buffer_impl( + axes, flow, static_cast(storage.get_buffer())); +} + /// Compute the bin of an array from a runtime list /// For example, [1,3,2] will return that bin of an array template diff --git a/include/bh_python/multi_weight.hpp b/include/bh_python/multi_weight.hpp new file mode 100644 index 000000000..d22716aae --- /dev/null +++ b/include/bh_python/multi_weight.hpp @@ -0,0 +1,306 @@ +#ifndef BOOST_HISTOGRAM_MULTI_WEIGHT_HPP +#define BOOST_HISTOGRAM_MULTI_WEIGHT_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace boost { +namespace histogram { + +template +struct multi_weight_base : public BASE { + using BASE::BASE; + + template + bool operator==(const S& values) const { + if(values.size() != this->size()) + return false; + + return std::equal(this->begin(), this->end(), values.begin()); + } 
+ + template + bool operator!=(const S& values) const { + return !operator==(values); + } +}; + +template +struct multi_weight_reference : public multi_weight_base> { + // using boost::span::span; + using multi_weight_base>::multi_weight_base; + + void operator()(const boost::span values) { operator+=(values); } + + // template + // bool operator==(const S values) const { + // if(values.size() != this->size()) + // return false; + // + // return std::equal(this->begin(), this->end(), values.begin()); + //} + // + // template + // bool operator!=(const S values) const { + // return !operator==(values); + //} + + // void operator+=(const std::vector values) { + // operator+=(boost::span(values)); } + + void operator+=(const boost::span values) { + // template + // void operator+=(const S values) { + if(values.size() != this->size()) + throw std::range_error("size does not match for += ref"); + auto it = this->begin(); + for(const T& x : values) + *it++ += x; + } + + template + multi_weight_reference& operator=(const S& values) { + if(values.size() != this->size()) + throw std::range_error("size does not match for = ref"); + auto it = this->begin(); + for(const T& x : values) + *it++ = x; + return *this; + } +}; + +template +struct multi_weight_value : public multi_weight_base> { + using multi_weight_base>::multi_weight_base; + + explicit multi_weight_value(const boost::span values) { + this->assign(values.begin(), values.end()); + } + multi_weight_value() = default; + + void operator()(const boost::span& values) { operator+=(values); } + + // template + // bool operator==(const S values) const { + // if(values.size() != this->size()) + // return false; + // + // return std::equal(this->begin(), this->end(), values.begin()); + //} + // + // template + // bool operator!=(const S values) const { + // return !operator==(values); + //} + // + // void operator+=(const std::vector values) { + // operator+=(boost::span(values)); } + + // template + // void operator+=(const S 
values) { + void operator+=(const boost::span& values) { + if(values.size() != this->size()) { + if(this->size() > 0) { + throw std::range_error("size does not match for += val"); + } + this->assign(values.begin(), values.end()); + return; + } + auto it = this->begin(); + for(const T& x : values) + *it++ += x; + } + + template + multi_weight_value& operator=(const S values) { + this->assign(values.begin(), values.end()); + return *this; + } +}; + +template +class multi_weight { + public: + using element_type = ElementType; + using value_type = multi_weight_value; + using reference = multi_weight_reference; + using const_reference = const reference; + + template + struct iterator_base + : public detail::iterator_adaptor, + std::size_t, + Reference> { + using base_type + = detail::iterator_adaptor, + std::size_t, + Reference>; + + iterator_base() = default; + iterator_base(const iterator_base& other) + : iterator_base(other.par_, other.base()) {} + iterator_base(MWPtr par, std::size_t idx) + : base_type{idx} + , par_{par} {} + + iterator_base& operator=(const iterator_base& other) = default; + + decltype(auto) operator*() const { + return Reference{par_->buffer_.get() + this->base() * par_->nelem_, + par_->nelem_}; + } + + MWPtr par_ = nullptr; + }; + + using iterator = iterator_base; + using const_iterator + = iterator_base; + + static constexpr bool has_threading_support() { return false; } + + explicit multi_weight(const std::size_t k = 0) + : nelem_{k} {} + + multi_weight(const multi_weight& other) { *this = other; } + + multi_weight& operator=(const multi_weight& other) { + // Protect against self assignment + if(this == &other) { + return *this; + } + nelem_ = other.nelem_; + reset(other.size_); + std::copy( + other.buffer_.get(), other.buffer_.get() + size_ * nelem_, buffer_.get()); + return *this; + } + + std::size_t size() const { return size_; } + + std::size_t nelem() const { return nelem_; } + + void reset(std::size_t n) { + size_ = n; + buffer_.reset(new 
element_type[size_ * nelem_]); + default_fill(); + } + + template ::value, bool> = true> + void default_fill() {} + + template ::value, bool> = true> + void default_fill() { + std::fill_n(buffer_.get(), size_ * nelem_, 0); + } + + iterator begin() { return {this, 0}; } + iterator end() { return {this, size_}; } + + const_iterator begin() const { return {this, 0}; } + const_iterator end() const { return {this, size_}; } + + reference operator[](std::size_t i) { + return reference{buffer_.get() + i * nelem_, nelem_}; + } + const_reference operator[](std::size_t i) const { + return const_reference{buffer_.get() + i * nelem_, nelem_}; + } + + template + bool operator==(const multi_weight& other) const { + if(size_ * nelem_ != other.size_ * other.nelem_) + return false; + return std::equal( + buffer_.get(), buffer_.get() + size_ * nelem_, other.buffer_.get()); + } + + template + bool operator!=(const multi_weight& other) const { + return !operator==(other); + } + + template + void operator+=(const multi_weight& other) { + if(size_ * nelem_ != other.size_ * other.nelem_) { + throw std::range_error("size does not match"); + } + for(std::size_t i = 0; i < size_ * nelem_; i++) { + buffer_[i] += other.buffer_[i]; + } + } + + template + void serialize(Archive& ar, unsigned /* version */) { + ar& make_nvp("size", size_); + ar& make_nvp("nelem", nelem_); + std::vector w; + if(Archive::is_loading::value) { + ar& make_nvp("buffer", w); + reset(size_); + std::swap_ranges(buffer_.get(), buffer_.get() + size_ * nelem_, w.data()); + } else { + w.assign(buffer_.get(), buffer_.get() + size_ * nelem_); + ar& make_nvp("buffer", w); + } + } + + element_type* get_buffer() { return buffer_.get(); } + + private: + std::size_t size_ = 0; // Number of bins + std::size_t nelem_ = 0; // Number of weights per bin + std::unique_ptr buffer_; +}; + +template +std::ostream& operator<<(std::ostream& os, const multi_weight_value& v) { + os << "multi_weight_value("; + bool first = true; + for(const T& x 
: v) + if(first) { + first = false; + os << x; + } else + os << ", " << x; + os << ")"; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const multi_weight_reference& v) { + os << "multi_weight_reference("; + bool first = true; + for(const T& x : v) + if(first) { + first = false; + os << x; + } else + os << ", " << x; + os << ")"; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const multi_weight& v) { + os << "multi_weight(\n"; + int index = 0; + for(const multi_weight_reference& x : v) { + os << "Index " << index << ": " << x << "\n"; + index++; + } + os << ")"; + return os; +} + +} // namespace histogram +} // namespace boost + +#endif diff --git a/include/bh_python/register_histogram.hpp b/include/bh_python/register_histogram.hpp index 5f784eb1e..b369e8b68 100644 --- a/include/bh_python/register_histogram.hpp +++ b/include/bh_python/register_histogram.hpp @@ -155,13 +155,13 @@ auto register_histogram(py::module& m, const char* name, const char* desc) { py::keep_alive<0, 1>()) .def("at", - [](const histogram_t& self, py::args& args) -> value_type { + [](const histogram_t& self, const py::args& args) -> value_type { auto int_args = py::cast>(args); return self.at(int_args); }) .def("_at_set", - [](histogram_t& self, const value_type& input, py::args& args) { + [](histogram_t& self, const value_type& input, const py::args& args) { auto int_args = py::cast>(args); self.at(int_args) = input; }) @@ -171,7 +171,7 @@ auto register_histogram(py::module& m, const char* name, const char* desc) { .def( "sum", [](const histogram_t& self, bool flow) { - py::gil_scoped_release const release; + const py::gil_scoped_release release; return bh::algorithm::sum( self, flow ? 
bh::coverage::all : bh::coverage::inner); }, @@ -180,7 +180,7 @@ auto register_histogram(py::module& m, const char* name, const char* desc) { .def( "empty", [](const histogram_t& self, bool flow) { - py::gil_scoped_release const release; + const py::gil_scoped_release release; return bh::algorithm::empty( self, flow ? bh::coverage::all : bh::coverage::inner); }, @@ -190,14 +190,198 @@ auto register_histogram(py::module& m, const char* name, const char* desc) { [](const histogram_t& self, const py::args& args) { auto commands = py::cast>(args); - py::gil_scoped_release const release; + const py::gil_scoped_release release; return bh::algorithm::reduce(self, commands); }) .def("project", [](const histogram_t& self, const py::args& values) { auto cpp_values = py::cast>(values); - py::gil_scoped_release const release; + const py::gil_scoped_release release; + return bh::algorithm::project(self, cpp_values); + }) + + .def("fill", &fill) + + .def(make_pickle()) + + ; + + return hist; +} + +template <> +auto inline register_histogram>(py::module& m, + const char* name, + const char* desc) { + using S = bh::multi_weight; + using histogram_t = bh::histogram; + using value_type = std::vector; + + py::class_ hist(m, name, desc, py::buffer_protocol()); + + hist.def(py::init(), "axes"_a, "storage"_a = S()) + + .def_buffer( + [](histogram_t& h) -> py::buffer_info { return make_buffer(h, false); }) + + .def("rank", &histogram_t::rank) + .def("size", &histogram_t::size) + .def("reset", &histogram_t::reset) + + .def("__copy__", [](const histogram_t& self) { return histogram_t(self); }) + .def("__deepcopy__", + [](const histogram_t& self, const py::object& memo) { + auto* a = new histogram_t(self); + const py::module copy = py::module::import("copy"); + for(unsigned i = 0; i < a->rank(); i++) { + bh::unsafe_access::axis(*a, i).metadata() + = copy.attr("deepcopy")(a->axis(i).metadata(), memo); + } + return a; + }) + + .def(py::self += py::self) + + .def("__eq__", + [](const 
histogram_t& self, const py::object& other) { + try { + return self == py::cast(other); + } catch(const py::cast_error&) { + return false; + } + }) + .def("__ne__", + [](const histogram_t& self, const py::object& other) { + try { + return self != py::cast(other); + } catch(const py::cast_error&) { + return true; + } + }) + + .def_property_readonly_static( + "_storage_type", + [](const py::object&) { + return py::type::of(); + }) + + ; + +// Protection against an overzealous warning system +// https://bugs.llvm.org/show_bug.cgi?id=43124 +#ifdef __clang__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wself-assign-overloaded" +#endif + def_optionally(hist, + bh::detail::has_operator_rdiv{}, + py::self /= py::self); + def_optionally(hist, + bh::detail::has_operator_rmul{}, + py::self *= py::self); + def_optionally(hist, + bh::detail::has_operator_rsub{}, + py::self -= py::self); +#ifdef __clang__ +#pragma GCC diagnostic pop +#endif + + hist.def( + "to_numpy", + [](histogram_t& h, bool flow) { + py::tuple tup(1 + h.rank()); + + // Add the histogram buffer as the first argument + unchecked_set(tup, 0, py::array(make_buffer(h, flow))); + + // Add the axis edges + h.for_each_axis([&tup, flow, i = 0U](const auto& ax) mutable { + unchecked_set(tup, ++i, axis::edges(ax, flow, true)); + }); + + return tup; + }, + "flow"_a = false) + + .def( + "view", + [](const py::object& self, bool flow) { + auto& h = py::cast(self); + return py::array(make_buffer(h, flow), self); + }, + "flow"_a = false) + + .def( + "axis", + [](const histogram_t& self, int i) -> py::object { + unsigned const ii + = i < 0 ? self.rank() - static_cast(std::abs(i)) + : static_cast(i); + + if(ii < self.rank()) { + const axis_variant& var = self.axis(ii); + return bh::axis::visit( + [](auto&& item) -> py::object { + // Here we return a new, no-copy py::object that + // is not yet tied to the histogram. 
py::keep_alive + // is needed to make sure the histogram is alive as long + // as the axes references are. + return py::cast(item, py::return_value_policy::reference); + }, + var); + } + + throw std::out_of_range("The axis value must be less than the rank"); + }, + "i"_a = 0, + py::keep_alive<0, 1>()) + + .def("at", + [](const histogram_t& self, const py::args& args) -> value_type { + auto int_args = py::cast>(args); + auto at_value = self.at(int_args); + return {at_value.begin(), at_value.end()}; + }) + + .def("_at_set", + [](histogram_t& self, const value_type& input, const py::args& args) { + auto int_args = py::cast>(args); + self.at(int_args) = input; + }) + + .def("__repr__", &shift_to_string) + + .def( + "sum", + [](const histogram_t& self, bool flow) -> value_type { + const py::gil_scoped_release release; + return bh::algorithm::sum( + self, flow ? bh::coverage::all : bh::coverage::inner); + }, + "flow"_a = false) + + .def( + "empty", + [](const histogram_t& self, bool flow) { + const py::gil_scoped_release release; + return bh::algorithm::empty( + self, flow ? 
bh::coverage::all : bh::coverage::inner); + }, + "flow"_a = false) + + .def("reduce", + [](const histogram_t& self, const py::args& args) { + auto commands + = py::cast>(args); + const py::gil_scoped_release release; + return bh::algorithm::reduce(self, commands); + }) + + .def("project", + [](const histogram_t& self, const py::args& values) { + auto cpp_values = py::cast>(values); + const py::gil_scoped_release release; return bh::algorithm::project(self, cpp_values); }) diff --git a/include/bh_python/register_storage.hpp b/include/bh_python/register_storage.hpp index a701d3dd9..537193b35 100644 --- a/include/bh_python/register_storage.hpp +++ b/include/bh_python/register_storage.hpp @@ -71,3 +71,39 @@ py::class_ inline register_storage(py::module& m, return storage; } + +/// Add helpers to the multi_weight storage type +template <> +py::class_ inline register_storage(py::module& m, + const char* name, + const char* desc) { + using A = storage::multi_weight; // match code above + + py::class_ storage(m, name, desc); + + storage.def(py::init(), py::arg("k") = 0) + .def("__eq__", + [](const A& self, const py::object& other) { + try { + return self == py::cast(other); + } catch(const py::cast_error&) { + return false; + } + }) + .def("__ne__", + [](const A& self, const py::object& other) { + try { + return !(self == py::cast(other)); + } catch(const py::cast_error&) { + return true; + } + }) + .def(make_pickle()) + .def("__copy__", [](const A& self) { return A(self); }) + .def("__deepcopy__", [](const A& self, const py::object&) { return A(self); }) + .def_property_readonly("nelem", [](const A& self) { return self.nelem(); }) + + ; + + return storage; +} diff --git a/include/bh_python/storage.hpp b/include/bh_python/storage.hpp index 33f994468..7377aa637 100644 --- a/include/bh_python/storage.hpp +++ b/include/bh_python/storage.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -27,6 +28,7 @@ using atomic_int64 = bh::dense_storage>; using 
double_ = bh::dense_storage; using unlimited = bh::unlimited_storage<>; using weight = bh::dense_storage>; +using multi_weight = bh::multi_weight; using mean = bh::dense_storage>; using weighted_mean = bh::dense_storage>; @@ -61,6 +63,11 @@ inline const char* name() { return "weight"; } +template <> +inline const char* name() { + return "multi_weight"; +} + template <> inline const char* name() { return "mean"; diff --git a/src/boost_histogram/_core/hist.pyi b/src/boost_histogram/_core/hist.pyi index 7333e9646..2272fa74e 100644 --- a/src/boost_histogram/_core/hist.pyi +++ b/src/boost_histogram/_core/hist.pyi @@ -77,20 +77,13 @@ class any_mean(_BaseHistogram): def at(self, *args: int) -> accumulators.Mean: ... def _at_set(self, value: accumulators.Mean, *args: int) -> None: ... def sum(self, flow: bool = ...) -> accumulators.Mean: ... - def fill( - self, - *args: ArrayLike, - weight: ArrayLike | None = ..., - sample: ArrayLike | None = ..., - ) -> None: ... class any_weighted_mean(_BaseHistogram): def at(self, *args: int) -> accumulators.WeightedMean: ... def _at_set(self, value: accumulators.WeightedMean, *args: int) -> None: ... def sum(self, flow: bool = ...) -> accumulators.WeightedMean: ... - def fill( - self, - *args: ArrayLike, - weight: ArrayLike | None = ..., - sample: ArrayLike | None = ..., - ) -> None: ... + +class any_multi_weight(_BaseHistogram): + def at(self, *args: int) -> float: ... + def _at_set(self, value: float, *args: int) -> ArrayLike: ... + def sum(self, flow: bool = ...) -> ArrayLike: ... diff --git a/src/boost_histogram/_core/storage.pyi b/src/boost_histogram/_core/storage.pyi index 7153e45f2..63700d0d4 100644 --- a/src/boost_histogram/_core/storage.pyi +++ b/src/boost_histogram/_core/storage.pyi @@ -16,3 +16,7 @@ class unlimited(_BaseStorage): ... class weight(_BaseStorage): ... class mean(_BaseStorage): ... class weighted_mean(_BaseStorage): ... + +class multi_weight(_BaseStorage): + @property + def nelem(self) -> int: ... 
diff --git a/src/boost_histogram/histogram.py b/src/boost_histogram/histogram.py index 7d5a5e36b..87f50fafa 100644 --- a/src/boost_histogram/histogram.py +++ b/src/boost_histogram/histogram.py @@ -95,6 +95,7 @@ def __dir__() -> list[str]: _core.hist.any_weight, _core.hist.any_mean, _core.hist.any_weighted_mean, + _core.hist.any_multi_weight, } logger = logging.getLogger(__name__) diff --git a/src/boost_histogram/storage.py b/src/boost_histogram/storage.py index 387bd7a13..3e27f6632 100644 --- a/src/boost_histogram/storage.py +++ b/src/boost_histogram/storage.py @@ -12,6 +12,7 @@ "Double", "Int64", "Mean", + "MultiWeight", "Storage", "Unlimited", "Weight", @@ -73,3 +74,10 @@ class Mean(store.mean, Storage, family=boost_histogram): class WeightedMean(store.weighted_mean, Storage, family=boost_histogram): accumulator = accumulators.WeightedMean + + +class MultiWeight(store.multi_weight, Storage, family=boost_histogram): + accumulator = float + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.nelem})" diff --git a/src/register_histograms.cpp b/src/register_histograms.cpp index db43cc9e7..6a94bb7cf 100644 --- a/src/register_histograms.cpp +++ b/src/register_histograms.cpp @@ -50,4 +50,10 @@ void register_histograms(py::module& hist) { hist, "any_weighted_mean", "N-dimensional histogram for weighted and sampled data with any axis types."); + + register_histogram( + hist, + "any_multi_weight", + "N-dimensional histogram for storing multiple weights at once with any axis " + "types."); } diff --git a/src/register_storage.cpp b/src/register_storage.cpp index 47c71fce5..39cd3ccd3 100644 --- a/src/register_storage.cpp +++ b/src/register_storage.cpp @@ -35,4 +35,9 @@ void register_storages(py::module& storage) { storage, "weighted_mean", "Dense storage which tracks means of weighted samples in each cell"); + + register_storage( + storage, + "multi_weight", + "Dense storage which tracks sums of weights for multiple weights per entry"); } diff --git 
a/tests/test_storage.py b/tests/test_storage.py index 59308803c..3add63e15 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -392,3 +392,12 @@ def test_non_uniform_rebin_with_weights(): [1.0, 1.05, 1.15, 3.0], ) ) + + +def test_multi_weight(): + x = np.array([1, 2]) + weights = np.array([[1, 2, 3], [4, 5, 6]]) + h = bh.Histogram(bh.axis.Regular(5, 0, 5), storage=bh.storage.MultiWeight(3)) + h.fill(x, sample=weights) + assert_array_equal(h[1], [1, 2, 3]) + assert_array_equal(h[2], [4, 5, 6])