@@ -67,6 +67,39 @@ recv_multipart_n(socket_ref s, OutputIt out, size_t n, recv_flags flags)
67
67
}
68
68
return msg_count;
69
69
}
70
+
71
+ inline bool is_little_endian ()
72
+ {
73
+ const uint16_t i = 0x01 ;
74
+ return *reinterpret_cast <const uint8_t *>(&i) == 0x01 ;
75
+ }
76
+
77
+ inline void write_network_order (unsigned char *buf, const uint32_t value)
78
+ {
79
+ if (is_little_endian ()) {
80
+ ZMQ_CONSTEXPR_VAR uint32_t mask = std::numeric_limits<std::uint8_t >::max ();
81
+ *buf++ = (value >> 24 ) & mask;
82
+ *buf++ = (value >> 16 ) & mask;
83
+ *buf++ = (value >> 8 ) & mask;
84
+ *buf++ = value & mask;
85
+ } else {
86
+ std::memcpy (buf, &value, sizeof (value));
87
+ }
88
+ }
89
+
90
+ inline uint32_t read_u32_network_order (const unsigned char *buf)
91
+ {
92
+ if (is_little_endian ()) {
93
+ return (static_cast <uint32_t >(buf[0 ]) << 24 )
94
+ + (static_cast <uint32_t >(buf[1 ]) << 16 )
95
+ + (static_cast <uint32_t >(buf[2 ]) << 8 )
96
+ + static_cast <uint32_t >(buf[3 ]);
97
+ } else {
98
+ uint32_t value;
99
+ std::memcpy (&value, buf, sizeof (value));
100
+ return value;
101
+ }
102
+ }
70
103
} // namespace detail
71
104
72
105
/* Receive a multipart message.
@@ -189,42 +222,37 @@ message_t encode(const Range &parts)
189
222
190
223
// First pass check sizes
191
224
for (const auto &part : parts) {
192
- size_t part_size = part.size ();
225
+ const size_t part_size = part.size ();
193
226
if (part_size > std::numeric_limits<std::uint32_t >::max ()) {
194
227
// Size value must fit into uint32_t.
195
228
throw std::range_error (" Invalid size, message part too large" );
196
229
}
197
- size_t count_size = 5 ;
198
- if (part_size < std::numeric_limits<std::uint8_t >::max ()) {
199
- count_size = 1 ;
200
- }
230
+ const size_t count_size =
231
+ part_size < std::numeric_limits<std::uint8_t >::max () ? 1 : 5 ;
201
232
mmsg_size += part_size + count_size;
202
233
}
203
234
204
235
message_t encoded (mmsg_size);
205
236
unsigned char *buf = encoded.data <unsigned char >();
206
237
for (const auto &part : parts) {
207
- uint32_t part_size = part.size ();
238
+ const uint32_t part_size = part.size ();
208
239
const unsigned char *part_data =
209
240
static_cast <const unsigned char *>(part.data ());
210
241
211
- // small part
212
242
if (part_size < std::numeric_limits<std::uint8_t >::max ()) {
243
+ // small part
213
244
*buf++ = (unsigned char ) part_size;
214
- memcpy (buf, part_data, part_size);
215
- buf += part_size;
216
- continue ;
245
+ } else {
246
+ // big part
247
+ *buf++ = std::numeric_limits<uint8_t >::max ();
248
+ detail::write_network_order (buf, part_size);
249
+ buf += sizeof (part_size);
217
250
}
218
-
219
- // big part
220
- *buf++ = std::numeric_limits<uint8_t >::max ();
221
- *buf++ = (part_size >> 24 ) & std::numeric_limits<std::uint8_t >::max ();
222
- *buf++ = (part_size >> 16 ) & std::numeric_limits<std::uint8_t >::max ();
223
- *buf++ = (part_size >> 8 ) & std::numeric_limits<std::uint8_t >::max ();
224
- *buf++ = part_size & std::numeric_limits<std::uint8_t >::max ();
225
- memcpy (buf, part_data, part_size);
251
+ std::memcpy (buf, part_data, part_size);
226
252
buf += part_size;
227
253
}
254
+
255
+ assert (static_cast <size_t >(buf - encoded.data <unsigned char >()) == mmsg_size);
228
256
return encoded;
229
257
}
230
258
@@ -251,22 +279,24 @@ template<class OutputIt> OutputIt decode(const message_t &encoded, OutputIt out)
251
279
while (source < limit) {
252
280
size_t part_size = *source++;
253
281
if (part_size == std::numeric_limits<std::uint8_t >::max ()) {
254
- if (source > limit - 4 ) {
282
+ if (static_cast < size_t >( limit - source) < sizeof ( uint32_t ) ) {
255
283
throw std::out_of_range (
256
284
" Malformed encoding, overflow in reading size" );
257
285
}
258
- part_size = (( uint32_t ) source[ 0 ] << 24 ) + (( uint32_t ) source[ 1 ] << 16 )
259
- + (( uint32_t ) source[ 2 ] << 8 ) + ( uint32_t ) source[ 3 ];
260
- source += 4 ;
286
+ part_size = detail::read_u32_network_order ( source);
287
+ // the part size is allowed to be less than 0xFF
288
+ source += sizeof ( uint32_t ) ;
261
289
}
262
290
263
- if (source > limit - part_size) {
291
+ if (static_cast < size_t >( limit - source) < part_size) {
264
292
throw std::out_of_range (" Malformed encoding, overflow in reading part" );
265
293
}
266
294
*out = message_t (source, part_size);
267
295
++out;
268
296
source += part_size;
269
297
}
298
+
299
+ assert (source == limit);
270
300
return out;
271
301
}
272
302
0 commit comments