Skip to content

Commit 11d5e45

Browse files
feat(rtbot): operator == implementation
1 parent 3e4c2e3 commit 11d5e45

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+2019
-790
lines changed

libs/api/include/rtbot/Program.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,14 @@ class Program {
114114

115115
// Message processing
116116
ProgramMsgBatch receive(const Message<NumberData>& msg, const std::string& port_id = "i1") {
117-
send_to_entry(msg, port_id);
117+
send_to_entry(msg, port_id, false);
118118
ProgramMsgBatch result = collect_outputs(false);
119119
clear_all_outputs();
120120
return result;
121121
}
122122

123123
ProgramMsgBatch receive_debug(const Message<NumberData>& msg, const std::string& port_id = "i1") {
124-
send_to_entry(msg, port_id);
124+
send_to_entry(msg, port_id, true);
125125
ProgramMsgBatch result = collect_outputs(true);
126126
clear_all_outputs();
127127
return result;
@@ -229,10 +229,10 @@ class Program {
229229
throw runtime_error("Could not resolve operator ID: " + id);
230230
}
231231

232-
void send_to_entry(const Message<NumberData>& msg, const std::string& port_id) {
232+
void send_to_entry(const Message<NumberData>& msg, const std::string& port_id, bool debug=false) {
233233
auto port_info = OperatorJson::parse_port_name(port_id);
234234
operators_[entry_operator_id_]->receive_data(create_message<NumberData>(msg.time, msg.data), port_info.index);
235-
operators_[entry_operator_id_]->execute();
235+
operators_[entry_operator_id_]->execute(debug);
236236
}
237237

238238
ProgramMsgBatch collect_outputs(bool debug_mode = false) {
@@ -255,7 +255,7 @@ class Program {
255255
// In debug mode, collect all ports
256256
if (debug_mode) {
257257
for (size_t i = 0; i < op->num_output_ports(); i++) {
258-
const auto& queue = op->get_output_queue(i);
258+
const auto& queue = op->get_debug_output_queue(i);
259259
if (!queue.empty()) {
260260
PortMsgBatch port_msgs;
261261
for (const auto& msg : queue) {

libs/core/include/rtbot/Buffer.h

Lines changed: 78 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,38 @@ class Buffer : public Operator {
7272
return std::sqrt(variance());
7373
}
7474

75-
Bytes collect() override {
76-
Bytes bytes = Operator::collect();
75+
bool equals(const Buffer& other) const {
76+
77+
if (window_size_ != other.window_size_) return false;
78+
79+
if (buffer_.size() != other.buffer_.size()) return false;
80+
81+
auto it1 = buffer_.begin();
82+
auto it2 = other.buffer_.begin();
83+
84+
for (; it1 != buffer_.end() && it2 != other.buffer_.end(); ++it1, ++it2) {
85+
const auto& msg1 = *it1;
86+
const auto& msg2 = *it2;
7787

78-
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&window_size_),
79-
reinterpret_cast<const uint8_t*>(&window_size_) + sizeof(window_size_));
88+
if (msg1 && msg2) {
89+
if (msg1->time != msg2->time) return false;
90+
if (msg1->hash() != msg2->hash()) return false;
91+
} else return false;
92+
}
93+
94+
if constexpr (Features::TRACK_SUM) {
95+
if (StateSerializer::hash_double(sum_) != StateSerializer::hash_double(other.sum_)) return false;
96+
}
97+
98+
if constexpr (Features::TRACK_VARIANCE) {
99+
if (StateSerializer::hash_double(M2_) != StateSerializer::hash_double(other.M2_)) return false;
100+
}
101+
102+
return Operator::equals(other);
103+
}
104+
105+
Bytes collect() override {
106+
Bytes bytes = Operator::collect();
80107

81108
size_t buffer_size = buffer_.size();
82109
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&buffer_size),
@@ -90,65 +117,68 @@ class Buffer : public Operator {
90117
bytes.insert(bytes.end(), msg_bytes.begin(), msg_bytes.end());
91118
}
92119

93-
if constexpr (Features::TRACK_SUM) {
94-
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&sum_),
95-
reinterpret_cast<const uint8_t*>(&sum_) + sizeof(sum_));
96-
}
97-
98-
if constexpr (Features::TRACK_VARIANCE) {
99-
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&M2_),
100-
reinterpret_cast<const uint8_t*>(&M2_) + sizeof(M2_));
101-
}
102-
103120
return bytes;
104121
}
105122

106123
void restore(Bytes::const_iterator& it) override {
107-
Operator::restore(it);
108-
109-
window_size_ = *reinterpret_cast<const size_t*>(&(*it));
110-
it += sizeof(size_t);
124+
// Call base restore first
125+
Operator::restore(it);
111126

112-
size_t buffer_size = *reinterpret_cast<const size_t*>(&(*it));
113-
it += sizeof(size_t);
127+
// ---- Read buffer_size safely ----
128+
size_t buffer_size;
129+
std::memcpy(&buffer_size, &(*it), sizeof(buffer_size));
130+
it += sizeof(buffer_size);
114131

132+
// ---- Deserialize buffer ----
115133
buffer_.clear();
116134
for (size_t i = 0; i < buffer_size; ++i) {
117-
size_t msg_size = *reinterpret_cast<const size_t*>(&(*it));
118-
it += sizeof(size_t);
119-
120-
Bytes msg_bytes(it, it + msg_size);
121-
buffer_.push_back(
122-
std::unique_ptr<Message<T>>(dynamic_cast<Message<T>*>(BaseMessage::deserialize(msg_bytes).release())));
123-
it += msg_size;
135+
// Read size of each message
136+
size_t msg_size;
137+
std::memcpy(&msg_size, &(*it), sizeof(msg_size));
138+
it += sizeof(msg_size);
139+
140+
// Extract message bytes
141+
Bytes msg_bytes(it, it + msg_size);
142+
143+
// Deserialize message and cast to derived type
144+
buffer_.push_back(
145+
std::unique_ptr<Message<T>>(
146+
dynamic_cast<Message<T>*>(BaseMessage::deserialize(msg_bytes).release())
147+
)
148+
);
149+
150+
it += msg_size;
124151
}
125152

153+
// ---- Optional statistics ----
126154
if constexpr (Features::TRACK_SUM) {
127-
sum_ = *reinterpret_cast<const double*>(&(*it));
128-
it += sizeof(double);
155+
sum_ = 0.0;
156+
if (!buffer_.empty()) {
157+
// First pass: compute sum
158+
for (const auto& msg : buffer_) {
159+
sum_ += msg->data.value;
160+
}
161+
}
129162
}
130163

131164
if constexpr (Features::TRACK_VARIANCE) {
132-
M2_ = *reinterpret_cast<const double*>(&(*it));
133-
it += sizeof(double);
134-
135-
// Recompute statistics from buffer to ensure consistency
136-
sum_ = 0.0;
137-
M2_ = 0.0;
138-
139-
if (!buffer_.empty()) {
140-
// First pass: compute mean
141-
for (const auto& msg : buffer_) {
142-
sum_ += msg->data.value;
165+
// Recompute statistics from buffer to ensure consistency
166+
sum_ = 0.0;
167+
M2_ = 0.0;
168+
169+
if (!buffer_.empty()) {
170+
// First pass: compute sum
171+
for (const auto& msg : buffer_) {
172+
sum_ += msg->data.value;
173+
}
174+
175+
// Second pass: compute M2
176+
double mean = sum_ / buffer_.size();
177+
for (const auto& msg : buffer_) {
178+
double delta = msg->data.value - mean;
179+
M2_ += delta * delta;
180+
}
143181
}
144-
145-
// Second pass: compute M2
146-
double mean = sum_ / buffer_.size();
147-
for (const auto& msg : buffer_) {
148-
double delta = msg->data.value - mean;
149-
M2_ += delta * delta;
150-
}
151-
}
152182
}
153183
}
154184

libs/core/include/rtbot/Demultiplexer.h

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,18 @@ class Demultiplexer : public Operator {
3535

3636
std::string type_name() const override { return "Demultiplexer"; }
3737

38-
size_t get_num_ports() const { return num_control_ports(); }
38+
size_t get_num_ports() const { return num_control_ports(); }
3939

40-
void receive_control(std::unique_ptr<BaseMessage> msg, size_t port_index) override {
41-
if (port_index >= num_control_ports()) {
42-
throw std::runtime_error("Invalid control port index");
43-
}
44-
45-
auto* ctrl_msg = dynamic_cast<const Message<BooleanData>*>(msg.get());
46-
if (!ctrl_msg) {
47-
throw std::runtime_error("Invalid control message type");
48-
}
49-
50-
// Update last timestamp
51-
control_ports_[port_index].last_timestamp = msg->time;
52-
53-
if (get_control_queue(port_index).size() == max_size_per_port_) {
54-
get_control_queue(port_index).pop_front();
55-
}
40+
bool equals(const Demultiplexer& other) const {
41+
return Operator::equals(other);
42+
}
43+
44+
bool operator==(const Demultiplexer& other) const {
45+
return equals(other);
46+
}
5647

57-
// Add message to queue
58-
get_control_queue(port_index).push_back(std::move(msg));
59-
control_ports_with_new_data_.insert(port_index);
48+
bool operator!=(const Demultiplexer& other) const {
49+
return !(*this == other);
6050
}
6151

6252
protected:

libs/core/include/rtbot/FilterByValue.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class FilterByValue : public Operator {
2121
add_output_port<T>();
2222
}
2323

24+
bool equals(const FilterByValue& other) const {
25+
return Operator::equals(other);
26+
}
27+
2428
protected:
2529
void process_data() override {
2630
auto& input_queue = get_data_queue(0);

libs/core/include/rtbot/Input.h

Lines changed: 12 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ class Input : public Operator {
2020
if (!PortType::is_valid_port_type(type)) {
2121
throw std::runtime_error("Unknown port type: " + type);
2222
}
23-
PortType::add_port(*this, type, true, true);
24-
last_sent_times_.push_back(0);
23+
PortType::add_port(*this, type, true, false ,true);
2524
port_type_names_.push_back(type);
2625
}
2726
}
@@ -31,68 +30,17 @@ class Input : public Operator {
3130
// Get port configuration
3231
const std::vector<std::string>& get_port_types() const { return port_type_names_; }
3332

34-
// Query port state
35-
bool has_sent(size_t port_index) const {
36-
validate_port_index(port_index);
37-
return last_sent_times_[port_index] > 0;
33+
bool equals(const Input& other) const {
34+
if (port_type_names_ != other.port_type_names_) return false;
35+
return Operator::equals(other);
3836
}
39-
40-
timestamp_t get_last_sent_time(size_t port_index) const {
41-
validate_port_index(port_index);
42-
return last_sent_times_[port_index];
43-
}
44-
45-
// State serialization
46-
Bytes collect() override {
47-
// First collect base state
48-
Bytes bytes = Operator::collect();
49-
50-
// Serialize last sent times
51-
size_t num_ports = last_sent_times_.size();
52-
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&num_ports),
53-
reinterpret_cast<const uint8_t*>(&num_ports) + sizeof(num_ports));
54-
55-
for (const auto& time : last_sent_times_) {
56-
bytes.insert(bytes.end(), reinterpret_cast<const uint8_t*>(&time),
57-
reinterpret_cast<const uint8_t*>(&time) + sizeof(time));
58-
}
59-
60-
// Serialize port type names
61-
StateSerializer::serialize_string_vector(bytes, port_type_names_);
62-
63-
return bytes;
64-
}
65-
66-
void restore(Bytes::const_iterator& it) override {
67-
// First restore base state
68-
Operator::restore(it);
69-
70-
// Restore last sent times
71-
size_t num_ports = *reinterpret_cast<const size_t*>(&(*it));
72-
it += sizeof(size_t);
73-
74-
StateSerializer::validate_port_count(num_ports, num_data_ports(), "Data");
75-
76-
last_sent_times_.clear();
77-
last_sent_times_.reserve(num_ports);
78-
for (size_t i = 0; i < num_ports; ++i) {
79-
timestamp_t time = *reinterpret_cast<const timestamp_t*>(&(*it));
80-
it += sizeof(timestamp_t);
81-
last_sent_times_.push_back(time);
82-
}
83-
84-
// Restore port type names
85-
StateSerializer::deserialize_string_vector(it, port_type_names_);
86-
87-
// Validate port types match
88-
if (port_type_names_.size() != num_data_ports()) {
89-
throw std::runtime_error("Port type count mismatch during restore");
90-
}
37+
38+
bool operator==(const Input& other) const {
39+
return equals(other);
9140
}
9241

93-
void reset() override {
94-
Operator::reset();
95-
last_sent_times_.assign(last_sent_times_.size(), 0);
42+
bool operator!=(const Input& other) const {
43+
return !(*this == other);
9644
}
9745

9846
// Do not throw exceptions in receive_data
@@ -107,36 +55,29 @@ class Input : public Operator {
10755
protected:
10856
void process_data() override {
10957
// Process each port independently to allow concurrent timestamps
110-
for (const auto& port_index : data_ports_with_new_data_) {
58+
for (int port_index = 0; port_index < num_data_ports(); port_index++) {
11159
const auto& input_queue = get_data_queue(port_index);
11260
if (input_queue.empty()) continue;
11361

11462
auto& output_queue = get_output_queue(port_index);
11563

11664
// Process all messages in input queue
117-
for (const auto& msg : input_queue) {
118-
// Only forward if timestamp is increasing for this specific port
119-
if (!has_sent(port_index) || msg->time > last_sent_times_[port_index]) {
120-
output_queue.push_back(std::move(msg->clone()));
121-
last_sent_times_[port_index] = msg->time;
122-
}
65+
for (const auto& msg : input_queue) {
66+
output_queue.push_back(std::move(msg->clone()));
12367
}
12468

12569
// Clear processed messages
12670
get_data_queue(port_index).clear();
12771
}
12872
}
12973

130-
void process_control() override {} // No control processing needed
13174

13275
private:
13376
void validate_port_index(size_t port_index) const {
13477
if (port_index >= num_data_ports()) {
13578
throw std::runtime_error("Invalid port index: " + std::to_string(port_index));
13679
}
13780
}
138-
139-
std::vector<timestamp_t> last_sent_times_;
14081
std::vector<std::string> port_type_names_;
14182
};
14283

0 commit comments

Comments
 (0)