@@ -67,6 +67,39 @@ recv_multipart_n(socket_ref s, OutputIt out, size_t n, recv_flags flags)
67
67
}
68
68
return msg_count;
69
69
}
70
+
71
+ inline bool is_little_endian ()
72
+ {
73
+ const uint16_t i = 0x01 ;
74
+ return *reinterpret_cast <const uint8_t *>(&i) == 0x01 ;
75
+ }
76
+
77
+ inline void write_network_order (unsigned char *buf, const uint32_t value)
78
+ {
79
+ if (is_little_endian ()) {
80
+ ZMQ_CONSTEXPR_VAR uint32_t mask = std::numeric_limits<std::uint8_t >::max ();
81
+ *buf++ = (value >> 24 ) & mask;
82
+ *buf++ = (value >> 16 ) & mask;
83
+ *buf++ = (value >> 8 ) & mask;
84
+ *buf++ = value & mask;
85
+ } else {
86
+ std::memcpy (buf, &value, sizeof (value));
87
+ }
88
+ }
89
+
90
+ inline uint32_t read_u32_network_order (const unsigned char *buf)
91
+ {
92
+ if (is_little_endian ()) {
93
+ return (static_cast <uint32_t >(buf[0 ]) << 24 )
94
+ + (static_cast <uint32_t >(buf[1 ]) << 16 )
95
+ + (static_cast <uint32_t >(buf[2 ]) << 8 )
96
+ + static_cast <uint32_t >(buf[3 ]);
97
+ } else {
98
+ uint32_t value;
99
+ std::memcpy (&value, buf, sizeof (value));
100
+ return value;
101
+ }
102
+ }
70
103
} // namespace detail
71
104
72
105
/* Receive a multipart message.
@@ -190,42 +223,37 @@ message_t encode(const Range &parts)
190
223
191
224
// First pass check sizes
192
225
for (const auto &part : parts) {
193
- size_t part_size = part.size ();
226
+ const size_t part_size = part.size ();
194
227
if (part_size > std::numeric_limits<std::uint32_t >::max ()) {
195
228
// Size value must fit into uint32_t.
196
229
throw std::range_error (" Invalid size, message part too large" );
197
230
}
198
- size_t count_size = 5 ;
199
- if (part_size < std::numeric_limits<std::uint8_t >::max ()) {
200
- count_size = 1 ;
201
- }
231
+ const size_t count_size =
232
+ part_size < std::numeric_limits<std::uint8_t >::max () ? 1 : 5 ;
202
233
mmsg_size += part_size + count_size;
203
234
}
204
235
205
236
message_t encoded (mmsg_size);
206
237
unsigned char *buf = encoded.data <unsigned char >();
207
238
for (const auto &part : parts) {
208
- uint32_t part_size = part.size ();
239
+ const uint32_t part_size = part.size ();
209
240
const unsigned char *part_data =
210
241
static_cast <const unsigned char *>(part.data ());
211
242
212
- // small part
213
243
if (part_size < std::numeric_limits<std::uint8_t >::max ()) {
244
+ // small part
214
245
*buf++ = (unsigned char ) part_size;
215
- memcpy (buf, part_data, part_size);
216
- buf += part_size;
217
- continue ;
246
+ } else {
247
+ // big part
248
+ *buf++ = std::numeric_limits<uint8_t >::max ();
249
+ detail::write_network_order (buf, part_size);
250
+ buf += sizeof (part_size);
218
251
}
219
-
220
- // big part
221
- *buf++ = std::numeric_limits<uint8_t >::max ();
222
- *buf++ = (part_size >> 24 ) & std::numeric_limits<std::uint8_t >::max ();
223
- *buf++ = (part_size >> 16 ) & std::numeric_limits<std::uint8_t >::max ();
224
- *buf++ = (part_size >> 8 ) & std::numeric_limits<std::uint8_t >::max ();
225
- *buf++ = part_size & std::numeric_limits<std::uint8_t >::max ();
226
- memcpy (buf, part_data, part_size);
252
+ std::memcpy (buf, part_data, part_size);
227
253
buf += part_size;
228
254
}
255
+
256
+ assert (static_cast <size_t >(buf - encoded.data <unsigned char >()) == mmsg_size);
229
257
return encoded;
230
258
}
231
259
@@ -252,22 +280,24 @@ template<class OutputIt> OutputIt decode(const message_t &encoded, OutputIt out)
252
280
while (source < limit) {
253
281
size_t part_size = *source++;
254
282
if (part_size == std::numeric_limits<std::uint8_t >::max ()) {
255
- if (source > limit - 4 ) {
283
+ if (static_cast < size_t >( limit - source) < sizeof ( uint32_t ) ) {
256
284
throw std::out_of_range (
257
285
" Malformed encoding, overflow in reading size" );
258
286
}
259
- part_size = (( uint32_t ) source[ 0 ] << 24 ) + (( uint32_t ) source[ 1 ] << 16 )
260
- + (( uint32_t ) source[ 2 ] << 8 ) + ( uint32_t ) source[ 3 ];
261
- source += 4 ;
287
+ part_size = detail::read_u32_network_order ( source);
288
+ // the part size is allowed to be less than 0xFF
289
+ source += sizeof ( uint32_t ) ;
262
290
}
263
291
264
- if (source > limit - part_size) {
292
+ if (static_cast < size_t >( limit - source) < part_size) {
265
293
throw std::out_of_range (" Malformed encoding, overflow in reading part" );
266
294
}
267
295
*out = message_t (source, part_size);
268
296
++out;
269
297
source += part_size;
270
298
}
299
+
300
+ assert (source == limit);
271
301
return out;
272
302
}
273
303
0 commit comments