Skip to content

Commit b495021

Browse files
authored
Merge pull request #404 from gummif/gfa/encoding-endian
Problem: No endian check in encoding
2 parents 9b824dd + 4784b74 commit b495021

File tree

2 files changed

+55
-25
lines changed

2 files changed

+55
-25
lines changed

tests/codec_multipart.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ TEST_CASE("multipart codec decode bad data overflow", "[codec_multipart]")
6969

7070
CHECK_THROWS_AS(
7171
multipart_t::decode(wrong_size),
72-
std::out_of_range);
72+
const std::out_of_range&);
7373
}
7474

7575
TEST_CASE("multipart codec decode bad data extra data", "[codec_multipart]")
@@ -83,7 +83,7 @@ TEST_CASE("multipart codec decode bad data extra data", "[codec_multipart]")
8383

8484
CHECK_THROWS_AS(
8585
multipart_t::decode(wrong_size),
86-
std::out_of_range);
86+
const std::out_of_range&);
8787
}
8888

8989

zmq_addon.hpp

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,39 @@ recv_multipart_n(socket_ref s, OutputIt out, size_t n, recv_flags flags)
6767
}
6868
return msg_count;
6969
}
70+
71+
inline bool is_little_endian()
72+
{
73+
const uint16_t i = 0x01;
74+
return *reinterpret_cast<const uint8_t *>(&i) == 0x01;
75+
}
76+
77+
inline void write_network_order(unsigned char *buf, const uint32_t value)
78+
{
79+
if (is_little_endian()) {
80+
ZMQ_CONSTEXPR_VAR uint32_t mask = std::numeric_limits<std::uint8_t>::max();
81+
*buf++ = (value >> 24) & mask;
82+
*buf++ = (value >> 16) & mask;
83+
*buf++ = (value >> 8) & mask;
84+
*buf++ = value & mask;
85+
} else {
86+
std::memcpy(buf, &value, sizeof(value));
87+
}
88+
}
89+
90+
inline uint32_t read_u32_network_order(const unsigned char *buf)
91+
{
92+
if (is_little_endian()) {
93+
return (static_cast<uint32_t>(buf[0]) << 24)
94+
+ (static_cast<uint32_t>(buf[1]) << 16)
95+
+ (static_cast<uint32_t>(buf[2]) << 8)
96+
+ static_cast<uint32_t>(buf[3]);
97+
} else {
98+
uint32_t value;
99+
std::memcpy(&value, buf, sizeof(value));
100+
return value;
101+
}
102+
}
70103
} // namespace detail
71104

72105
/* Receive a multipart message.
@@ -190,42 +223,37 @@ message_t encode(const Range &parts)
190223

191224
// First pass check sizes
192225
for (const auto &part : parts) {
193-
size_t part_size = part.size();
226+
const size_t part_size = part.size();
194227
if (part_size > std::numeric_limits<std::uint32_t>::max()) {
195228
// Size value must fit into uint32_t.
196229
throw std::range_error("Invalid size, message part too large");
197230
}
198-
size_t count_size = 5;
199-
if (part_size < std::numeric_limits<std::uint8_t>::max()) {
200-
count_size = 1;
201-
}
231+
const size_t count_size =
232+
part_size < std::numeric_limits<std::uint8_t>::max() ? 1 : 5;
202233
mmsg_size += part_size + count_size;
203234
}
204235

205236
message_t encoded(mmsg_size);
206237
unsigned char *buf = encoded.data<unsigned char>();
207238
for (const auto &part : parts) {
208-
uint32_t part_size = part.size();
239+
const uint32_t part_size = part.size();
209240
const unsigned char *part_data =
210241
static_cast<const unsigned char *>(part.data());
211242

212-
// small part
213243
if (part_size < std::numeric_limits<std::uint8_t>::max()) {
244+
// small part
214245
*buf++ = (unsigned char) part_size;
215-
memcpy(buf, part_data, part_size);
216-
buf += part_size;
217-
continue;
246+
} else {
247+
// big part
248+
*buf++ = std::numeric_limits<uint8_t>::max();
249+
detail::write_network_order(buf, part_size);
250+
buf += sizeof(part_size);
218251
}
219-
220-
// big part
221-
*buf++ = std::numeric_limits<uint8_t>::max();
222-
*buf++ = (part_size >> 24) & std::numeric_limits<std::uint8_t>::max();
223-
*buf++ = (part_size >> 16) & std::numeric_limits<std::uint8_t>::max();
224-
*buf++ = (part_size >> 8) & std::numeric_limits<std::uint8_t>::max();
225-
*buf++ = part_size & std::numeric_limits<std::uint8_t>::max();
226-
memcpy(buf, part_data, part_size);
252+
std::memcpy(buf, part_data, part_size);
227253
buf += part_size;
228254
}
255+
256+
assert(static_cast<size_t>(buf - encoded.data<unsigned char>()) == mmsg_size);
229257
return encoded;
230258
}
231259

@@ -252,22 +280,24 @@ template<class OutputIt> OutputIt decode(const message_t &encoded, OutputIt out)
252280
while (source < limit) {
253281
size_t part_size = *source++;
254282
if (part_size == std::numeric_limits<std::uint8_t>::max()) {
255-
if (source > limit - 4) {
283+
if (static_cast<size_t>(limit - source) < sizeof(uint32_t)) {
256284
throw std::out_of_range(
257285
"Malformed encoding, overflow in reading size");
258286
}
259-
part_size = ((uint32_t) source[0] << 24) + ((uint32_t) source[1] << 16)
260-
+ ((uint32_t) source[2] << 8) + (uint32_t) source[3];
261-
source += 4;
287+
part_size = detail::read_u32_network_order(source);
288+
// the part size is allowed to be less than 0xFF
289+
source += sizeof(uint32_t);
262290
}
263291

264-
if (source > limit - part_size) {
292+
if (static_cast<size_t>(limit - source) < part_size) {
265293
throw std::out_of_range("Malformed encoding, overflow in reading part");
266294
}
267295
*out = message_t(source, part_size);
268296
++out;
269297
source += part_size;
270298
}
299+
300+
assert(source == limit);
271301
return out;
272302
}
273303

0 commit comments

Comments
 (0)