Skip to content

Commit 8f4b21c

Browse files
committed
refactor: Update SegmentTree initialization to use SegmentTreeStoreSon for improved structure
1 parent 8d72133 commit 8f4b21c

File tree

2 files changed

+77
-68
lines changed

2 files changed

+77
-68
lines changed

test/point_set_range_composite.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ int main() {
3636
affines.emplace_back(a, b);
3737
}
3838

39-
SegmentTree<AffineMonoid> sgt(affines);
39+
SegmentTree<SegmentTreeStoreSon<AffineMonoid>> sgt(affines);
4040
while (q--) {
4141
size_t op;
4242
cin >> op;

weilycoder/ds/segment_tree.hpp

Lines changed: 76 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,122 +12,134 @@
1212
#include <vector>
1313

1414
namespace weilycoder {
15-
/**
16-
* @brief Segment Tree (point update and range query)
17-
* @tparam Monoid The monoid defining the operation and identity
18-
* @tparam ptr_t The type used for indexing (default: size_t)
19-
*/
20-
template <typename Monoid, typename ptr_t = size_t> struct SegmentTree {
21-
using T = typename Monoid::value_type;
15+
template <typename _Monoid, typename _ptr_t = size_t> struct SegmentTreeStoreSon {
16+
protected:
17+
using T = typename _Monoid::value_type;
18+
using ptr_t = _ptr_t;
19+
using Monoid = _Monoid;
2220
static constexpr ptr_t null = std::numeric_limits<ptr_t>::max();
2321

24-
/**
25-
* @brief Node structure for the segment tree
26-
*/
22+
private:
2723
struct Node {
2824
T value;
2925
ptr_t left, right;
3026

3127
Node() : value(Monoid::identity()), left(null), right(null) {}
3228
};
3329

34-
private:
35-
ptr_t tl, tr;
30+
ptr_t st, ed;
3631
std::vector<Node> data;
3732

38-
void pushup(size_t node) {
39-
data[node].value =
40-
Monoid::operation(data[data[node].left].value, data[data[node].right].value);
41-
}
42-
43-
ptr_t init(ptr_t l, ptr_t r) {
33+
ptr_t build(ptr_t l, ptr_t r) {
4434
ptr_t node = data.size();
4535
data.emplace_back();
46-
4736
if (r - l > 1) {
4837
ptr_t mid = l + ((r - l) >> 1);
49-
ptr_t left = init(l, mid), right = init(mid, r);
38+
ptr_t left = build(l, mid), right = build(mid, r);
5039
data[node].left = left, data[node].right = right;
5140
}
52-
5341
return node;
5442
}
5543

5644
ptr_t init(ptr_t l, ptr_t r, const std::vector<T> &arr) {
5745
ptr_t node = data.size();
5846
data.emplace_back();
59-
60-
if (r - l == 1)
47+
if (r - l == 1) {
6148
data[node].value = arr[l];
62-
else {
49+
} else {
6350
ptr_t mid = l + ((r - l) >> 1);
64-
ptr_t left = init(l, mid, arr);
65-
ptr_t right = init(mid, r, arr);
51+
ptr_t left = init(l, mid, arr), right = init(mid, r, arr);
6652
data[node].left = left, data[node].right = right;
6753
pushup(node);
6854
}
69-
7055
return node;
7156
}
7257

73-
void build(ptr_t l, ptr_t r) {
74-
if (r - l > 0) {
75-
data.reserve((r - l) * 2 - 1);
76-
init(l, r);
77-
}
58+
protected:
59+
ptr_t get_st() const { return st; }
60+
ptr_t get_ed() const { return ed; }
61+
62+
T &get_value(ptr_t node) { return data[node].value; }
63+
const T &get_value(ptr_t node) const { return data[node].value; }
64+
65+
ptr_t get_lc(ptr_t node) const { return data[node].left; }
66+
ptr_t get_rc(ptr_t node) const { return data[node].right; }
67+
68+
void pushdown(ptr_t node) const {}
69+
void pushup(ptr_t node) {
70+
data[node].value =
71+
Monoid::operation(data[data[node].left].value, data[data[node].right].value);
7872
}
7973

80-
void build(const std::vector<T> &arr) {
81-
if (!arr.empty()) {
82-
data.reserve(arr.size() * 2 - 1);
83-
init(0, arr.size(), arr);
84-
}
74+
explicit SegmentTreeStoreSon(ptr_t size) : st(0), ed(size) {
75+
data.reserve(size * 2 - 1);
76+
build(st, ed);
77+
}
78+
79+
explicit SegmentTreeStoreSon(ptr_t st, ptr_t ed) : st(st), ed(ed) {
80+
data.reserve((ed - st) * 2 - 1);
81+
build(st, ed);
8582
}
8683

84+
explicit SegmentTreeStoreSon(const std::vector<T> &arr)
85+
: st(0), ed(static_cast<ptr_t>(arr.size())) {
86+
data.reserve(arr.size() * 2 - 1);
87+
init(0, arr.size(), arr);
88+
}
89+
};
90+
91+
template <class SegmentBase> struct SegmentTree : private SegmentBase {
92+
using Monoid = typename SegmentBase::Monoid;
93+
using ptr_t = typename SegmentBase::ptr_t;
94+
using T = typename Monoid::value_type;
95+
static constexpr ptr_t null = SegmentBase::null;
96+
97+
private:
8798
void point_set(ptr_t node, ptr_t l, ptr_t r, ptr_t pos, const T &val) {
8899
if (r - l == 1)
89-
data[node].value = val;
100+
SegmentBase::get_value(node) = val;
90101
else {
91102
ptr_t mid = l + ((r - l) >> 1);
92103
if (pos < mid)
93-
point_set(data[node].left, l, mid, pos, val);
104+
point_set(SegmentBase::get_lc(node), l, mid, pos, val);
94105
else
95-
point_set(data[node].right, mid, r, pos, val);
96-
pushup(node);
106+
point_set(SegmentBase::get_rc(node), mid, r, pos, val);
107+
SegmentBase::pushup(node);
97108
}
98109
}
99110

100111
void point_update(ptr_t node, ptr_t l, ptr_t r, ptr_t pos, const T &val) {
101112
if (r - l == 1)
102-
data[node].value = Monoid::operation(data[node].value, val);
113+
SegmentBase::get_value(node) =
114+
Monoid::operation(SegmentBase::get_value(node), val);
103115
else {
104116
ptr_t mid = l + ((r - l) >> 1);
105117
if (pos < mid)
106-
point_update(data[node].left, l, mid, pos, val);
118+
point_update(SegmentBase::get_lc(node), l, mid, pos, val);
107119
else
108-
point_update(data[node].right, mid, r, pos, val);
109-
pushup(node);
120+
point_update(SegmentBase::get_rc(node), mid, r, pos, val);
121+
SegmentBase::pushup(node);
110122
}
111123
}
112124

113125
T point_query(ptr_t node, ptr_t l, ptr_t r, ptr_t pos) const {
114126
if (r - l == 1)
115-
return data[node].value;
127+
return SegmentBase::get_value(node);
116128
ptr_t mid = l + ((r - l) >> 1);
117129
if (pos < mid)
118-
return point_query(data[node].left, l, mid, pos);
130+
return point_query(SegmentBase::get_lc(node), l, mid, pos);
119131
else
120-
return point_query(data[node].right, mid, r, pos);
132+
return point_query(SegmentBase::get_rc(node), mid, r, pos);
121133
}
122134

123135
T range_query(ptr_t node, ptr_t l, ptr_t r, ptr_t ql, ptr_t qr) const {
124136
if (ql >= r || qr <= l)
125137
return Monoid::identity();
126138
if (ql <= l && r <= qr)
127-
return data[node].value;
139+
return SegmentBase::get_value(node);
128140
ptr_t mid = l + ((r - l) >> 1);
129-
T left_res = range_query(data[node].left, l, mid, ql, qr);
130-
T right_res = range_query(data[node].right, mid, r, ql, qr);
141+
T left_res = range_query(SegmentBase::get_lc(node), l, mid, ql, qr);
142+
T right_res = range_query(SegmentBase::get_rc(node), mid, r, ql, qr);
131143
return Monoid::operation(left_res, right_res);
132144
}
133145

@@ -136,33 +148,30 @@ template <typename Monoid, typename ptr_t = size_t> struct SegmentTree {
136148
* @brief Constructs a SegmentTree with given size
137149
* @param size The size of the array
138150
*/
139-
explicit SegmentTree(ptr_t size) : tl(0), tr(size) { build(tl, tr); }
151+
explicit SegmentTree(ptr_t size) : SegmentBase(size) {}
140152

141153
/**
142154
* @brief Constructs a SegmentTree for the range [left, right)
143155
* @param left The left index (inclusive)
144156
* @param right The right index (exclusive)
145157
*/
146-
explicit SegmentTree(ptr_t left, ptr_t right) : tl(left), tr(right) { build(tl, tr); }
158+
explicit SegmentTree(ptr_t left, ptr_t right) : SegmentBase(left, right) {}
147159

148160
/**
149161
* @brief Constructs a SegmentTree from an initial array
150162
* @param arr Initial array of elements
151163
*/
152-
explicit SegmentTree(const std::vector<T> &arr)
153-
: tl(0), tr(static_cast<ptr_t>(arr.size())) {
154-
build(arr);
155-
}
164+
explicit SegmentTree(const std::vector<T> &arr) : SegmentBase(arr) {}
156165

157166
/**
158167
* @brief Sets the value at position pos to val
159168
* @param pos The position to update
160169
* @param val The new value
161170
*/
162171
void point_set(ptr_t pos, const T &val) {
163-
if (pos < tl || pos >= tr)
172+
if (pos < get_st() || pos >= get_ed())
164173
throw std::out_of_range("SegmentTree::point_set: position out of range");
165-
point_set(0, tl, tr, pos, val);
174+
point_set(0, get_st(), get_ed(), pos, val);
166175
}
167176

168177
/**
@@ -171,9 +180,9 @@ template <typename Monoid, typename ptr_t = size_t> struct SegmentTree {
171180
* @param val The value to combine
172181
*/
173182
void point_update(ptr_t pos, const T &val) {
174-
if (pos < tl || pos >= tr)
183+
if (pos < get_st() || pos >= get_ed())
175184
throw std::out_of_range("SegmentTree::point_update: position out of range");
176-
point_update(0, tl, tr, pos, val);
185+
point_update(0, get_st(), get_ed(), pos, val);
177186
}
178187

179188
/**
@@ -182,9 +191,9 @@ template <typename Monoid, typename ptr_t = size_t> struct SegmentTree {
182191
* @return The value at position pos
183192
*/
184193
T point_query(ptr_t pos) const {
185-
if (pos < tl || pos >= tr)
194+
if (pos < get_st() || pos >= get_ed())
186195
throw std::out_of_range("SegmentTree::point_query: position out of range");
187-
return point_query(0, tl, tr, pos);
196+
return point_query(0, get_st(), get_ed(), pos);
188197
}
189198

190199
/**
@@ -194,13 +203,13 @@ template <typename Monoid, typename ptr_t = size_t> struct SegmentTree {
194203
* @return The result of the monoid operation over the range
195204
*/
196205
T range_query(ptr_t left, ptr_t right) const {
197-
if (left < tl || right > tr || left > right)
206+
if (left < get_st() || right > get_ed() || left > right)
198207
throw std::out_of_range("SegmentTree::range_query: range out of bounds");
199-
return range_query(0, tl, tr, left, right);
208+
return range_query(0, get_st(), get_ed(), left, right);
200209
}
201210

202-
ptr_t left() const { return tl; }
203-
ptr_t right() const { return tr; }
211+
ptr_t get_st() const { return SegmentBase::get_st(); }
212+
ptr_t get_ed() const { return SegmentBase::get_ed(); }
204213
};
205214
} // namespace weilycoder
206215

0 commit comments

Comments
 (0)