|
5 | 5 | #include "envoy/buffer/buffer.h" |
6 | 6 |
|
7 | 7 | #include "source/common/api/os_sys_calls_impl.h" |
| 8 | +#include "source/common/common/safe_memcpy.h" |
8 | 9 | #include "source/common/common/utility.h" |
9 | 10 | #include "source/common/event/file_event_impl.h" |
10 | 11 | #include "source/common/network/address_impl.h" |
@@ -280,6 +281,44 @@ absl::optional<uint32_t> maybeGetPacketsDroppedFromHeader([[maybe_unused]] const |
280 | 281 | return absl::nullopt; |
281 | 282 | } |
282 | 283 |
|
| 284 | +template <typename T> T getUnsignedIntFromHeader(const cmsghdr& cmsg) { |
| 285 | + static_assert(std::is_unsigned_v<T>, "return type must be unsigned integral"); |
| 286 | + T value; |
| 287 | + safeMemcpyUnsafeSrc(&value, CMSG_DATA(&cmsg)); |
| 288 | + return value; |
| 289 | +} |
| 290 | + |
| 291 | +template <typename T> absl::optional<T> maybeGetUnsignedIntFromHeader(const cmsghdr& cmsg) { |
| 292 | + static_assert(std::is_unsigned_v<T>, "return type must be unsigned integral"); |
| 293 | + switch (cmsg.cmsg_len) { |
| 294 | + case CMSG_LEN(sizeof(uint8_t)): |
| 295 | + return static_cast<T>(getUnsignedIntFromHeader<uint8_t>(cmsg)); |
| 296 | + case CMSG_LEN(sizeof(uint16_t)): |
| 297 | + return static_cast<T>(getUnsignedIntFromHeader<uint16_t>(cmsg)); |
| 298 | + case CMSG_LEN(sizeof(uint32_t)): |
| 299 | + return static_cast<T>(getUnsignedIntFromHeader<uint32_t>(cmsg)); |
| 300 | + case CMSG_LEN(sizeof(uint64_t)): |
| 301 | + return static_cast<T>(getUnsignedIntFromHeader<uint64_t>(cmsg)); |
| 302 | + default:; |
| 303 | + } |
| 304 | + IS_ENVOY_BUG( |
| 305 | + fmt::format("unexpected cmsg_len value for unsigned integer payload: {}", cmsg.cmsg_len)); |
| 306 | + return absl::nullopt; |
| 307 | +} |
| 308 | + |
| 309 | +absl::optional<uint8_t> maybeGetTosFromHeader(const cmsghdr& cmsg) { |
| 310 | + if ( |
| 311 | +#ifdef __APPLE__ |
| 312 | + (cmsg.cmsg_level == IPPROTO_IP && cmsg.cmsg_type == IP_RECVTOS) || |
| 313 | +#else |
| 314 | + (cmsg.cmsg_level == IPPROTO_IP && cmsg.cmsg_type == IP_TOS) || |
| 315 | +#endif // __APPLE__ |
| 316 | + (cmsg.cmsg_level == IPPROTO_IPV6 && cmsg.cmsg_type == IPV6_TCLASS)) { |
| 317 | + return maybeGetUnsignedIntFromHeader<uint8_t>(cmsg); |
| 318 | + } |
| 319 | + return absl::nullopt; |
| 320 | +} |
| 321 | + |
283 | 322 | Api::IoCallUint64Result IoSocketHandleImpl::recvmsg(Buffer::RawSlice* slices, |
284 | 323 | const uint64_t num_slice, uint32_t self_port, |
285 | 324 | const UdpSaveCmsgConfig& save_cmsg_config, |
@@ -359,17 +398,17 @@ Api::IoCallUint64Result IoSocketHandleImpl::recvmsg(Buffer::RawSlice* slices, |
359 | 398 | } |
360 | 399 | #ifdef UDP_GRO |
361 | 400 | if (cmsg->cmsg_level == SOL_UDP && cmsg->cmsg_type == UDP_GRO) { |
362 | | - output.msg_[0].gso_size_ = *reinterpret_cast<uint16_t*>(CMSG_DATA(cmsg)); |
| 401 | + absl::optional<uint16_t> maybe_gso = maybeGetUnsignedIntFromHeader<uint16_t>(*cmsg); |
| 402 | + if (maybe_gso) { |
| 403 | + output.msg_[0].gso_size_ = *maybe_gso; |
| 404 | + } |
363 | 405 | } |
364 | 406 | #endif |
365 | | - if (receive_ecn_ && |
366 | | -#ifdef __APPLE__ |
367 | | - ((cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_RECVTOS) || |
368 | | -#else |
369 | | - ((cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TOS) || |
370 | | -#endif // __APPLE__ |
371 | | - (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_TCLASS))) { |
372 | | - output.msg_[0].tos_ = *(reinterpret_cast<uint8_t*>(CMSG_DATA(cmsg))); |
| 407 | + if (receive_ecn_) { |
| 408 | + absl::optional<uint8_t> maybe_tos = maybeGetTosFromHeader(*cmsg); |
| 409 | + if (maybe_tos) { |
| 410 | + output.msg_[0].tos_ = *maybe_tos; |
| 411 | + } |
373 | 412 | } |
374 | 413 | } |
375 | 414 | } |
@@ -455,15 +494,12 @@ Api::IoCallUint64Result IoSocketHandleImpl::recvmmsg(RawSliceArrays& slices, uin |
455 | 494 | output.msg_[0].saved_cmsg_ = cmsg_slice; |
456 | 495 | } |
457 | 496 | Address::InstanceConstSharedPtr addr = maybeGetDstAddressFromHeader(*cmsg, self_port); |
458 | | - if (receive_ecn_ && |
459 | | -#ifdef __APPLE__ |
460 | | - ((cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_RECVTOS) || |
461 | | -#else |
462 | | - ((cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TOS) || |
463 | | -#endif // __APPLE__ |
464 | | - (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_TCLASS))) { |
465 | | - output.msg_[i].tos_ = *(reinterpret_cast<uint8_t*>(CMSG_DATA(cmsg))); |
466 | | - continue; |
| 497 | + if (receive_ecn_) { |
| 498 | + absl::optional<uint8_t> maybe_tos = maybeGetTosFromHeader(*cmsg); |
| 499 | + if (maybe_tos) { |
| 500 | + output.msg_[0].tos_ = *maybe_tos; |
| 501 | + continue; |
| 502 | + } |
467 | 503 | } |
468 | 504 | if (addr != nullptr) { |
469 | 505 | // This is a IP packet info message. |
|
0 commit comments