diff --git a/Packet++/src/Packet.cpp b/Packet++/src/Packet.cpp index 4ff44a2ba6..a60ea2ce78 100644 --- a/Packet++/src/Packet.cpp +++ b/Packet++/src/Packet.cpp @@ -63,39 +63,50 @@ namespace pcpp m_FirstLayer = createFirstLayer(linkType); - m_LastLayer = m_FirstLayer; - Layer* curLayer = m_FirstLayer; - while (curLayer != nullptr && - (parseUntil == UnknownProtocol || !curLayer->isMemberOfProtocolFamily(parseUntil)) && - curLayer->getOsiModelLayer() <= parseUntilLayer) + // As the stop conditions are inclusive, the parse must go one layer further and then roll back if needed + bool rollbackLastLayer = false; + bool foundTargetProtocol = false; + for (auto* curLayer = m_FirstLayer; curLayer != nullptr; curLayer = curLayer->getNextLayer()) { - curLayer->parseNextLayer(); + // Mark the current layer as allocated in the packet curLayer->m_IsAllocatedInPacket = true; - curLayer = curLayer->getNextLayer(); - if (curLayer != nullptr) - m_LastLayer = curLayer; - } + m_LastLayer = curLayer; // Update last layer to current layer - if (curLayer != nullptr && curLayer->isMemberOfProtocolFamily(parseUntil)) - { - curLayer->m_IsAllocatedInPacket = true; - } + // If the current layer is of a higher OSI layer than the target, stop parsing + if (curLayer->getOsiModelLayer() > parseUntilLayer) + { + rollbackLastLayer = true; + break; + } - if (curLayer != nullptr && curLayer->getOsiModelLayer() > parseUntilLayer) - { - // don't delete the first layer. If already past the target layer, treat the same as if the layer was found. - if (curLayer == m_FirstLayer) + // If we are searching for a specific layer protocol, record when we find at least one target. + const bool matchesTarget = curLayer->isMemberOfProtocolFamily(parseUntil); + if (parseUntil != UnknownProtocol && matchesTarget) { - curLayer->m_IsAllocatedInPacket = true; + foundTargetProtocol = true; } - else + + // If we have found the target protocol already, we are parsing until we find a different protocol + if (foundTargetProtocol && !matchesTarget) { - m_LastLayer = curLayer->getPrevLayer(); - delete curLayer; - m_LastLayer->m_NextLayer = nullptr; + rollbackLastLayer = true; + break; } + + // Parse the next layer. This will update the next layer pointer of the current layer. + curLayer->parseNextLayer(); + } + + // Roll back one layer, if parsing with search condition as the conditions are inclusive. + // Don't delete the first layer. If already past the target layer, treat the same as if the layer was found. + if (rollbackLastLayer && m_LastLayer != m_FirstLayer) + { + m_LastLayer = m_LastLayer->getPrevLayer(); + delete m_LastLayer->m_NextLayer; + m_LastLayer->m_NextLayer = nullptr; } + // If there is data left in the raw packet that doesn't belong to any layer, create a PacketTrailerLayer if (m_LastLayer != nullptr && parseUntil == UnknownProtocol && parseUntilLayer == OsiModelLayerUnknown) { // find if there is data left in the raw packet that doesn't belong to any layer. In that case it's probably diff --git a/Tests/Packet++Test/TestDefinition.h b/Tests/Packet++Test/TestDefinition.h index 90e2c06f9a..9248e028d4 100644 --- a/Tests/Packet++Test/TestDefinition.h +++ b/Tests/Packet++Test/TestDefinition.h @@ -60,6 +60,7 @@ PTF_TEST_CASE(ResizeLayerTest); PTF_TEST_CASE(PrintPacketAndLayersTest); PTF_TEST_CASE(ProtocolFamilyMembershipTest); PTF_TEST_CASE(PacketParseLayerLimitTest); +PTF_TEST_CASE(PacketParseMultiLayerTest); // Implemented in HttpTests.cpp PTF_TEST_CASE(HttpRequestParseMethodTest); diff --git a/Tests/Packet++Test/Tests/PacketTests.cpp b/Tests/Packet++Test/Tests/PacketTests.cpp index d28e470868..ae53f5c60f 100644 --- a/Tests/Packet++Test/Tests/PacketTests.cpp +++ b/Tests/Packet++Test/Tests/PacketTests.cpp @@ -20,6 +20,7 @@ #include "PayloadLayer.h" #include "GeneralUtils.h" #include "SystemUtils.h" +#include "BgpLayer.h" using pcpp_tests::utils::createPacketFromHexResource; @@ -1064,3 +1065,33 @@ PTF_TEST_CASE(PacketParseLayerLimitTest) pcpp::Packet packet1(rawPacket1.get(), pcpp::OsiModelTransportLayer); PTF_ASSERT_EQUAL(packet1.getLastLayer()->getOsiModelLayer(), pcpp::OsiModelTransportLayer); } + +PTF_TEST_CASE(PacketParseMultiLayerTest) +{ + // The BGP packet has 4 BGP messages inside. + auto rawPacket = createPacketFromHexResource("PacketExamples/Bgp_update2.dat"); + + // Limit to BGP layer + pcpp::Packet packet(rawPacket.get(), pcpp::BGP); + + const size_t expectedNumOfBgpMessages = 4; + size_t actualNumOfBgpMessages = 0; + + pcpp::BgpLayer* bgpLayer = packet.getLayerOfType(); + if (bgpLayer != nullptr) + { + ++actualNumOfBgpMessages; + } + + // The fallback iteration uses expected * 2, just to be sure we won't get into an infinite loop + for (; bgpLayer != nullptr && actualNumOfBgpMessages < expectedNumOfBgpMessages * 2;) + { + bgpLayer = packet.getNextLayerOfType(bgpLayer); + if (bgpLayer != nullptr) + { + ++actualNumOfBgpMessages; + } + } + + PTF_ASSERT_EQUAL(actualNumOfBgpMessages, expectedNumOfBgpMessages); +} diff --git a/Tests/Packet++Test/main.cpp b/Tests/Packet++Test/main.cpp index 08fd40784a..5a6fd6cf1e 100644 --- a/Tests/Packet++Test/main.cpp +++ b/Tests/Packet++Test/main.cpp @@ -169,6 +169,7 @@ int main(int argc, char* argv[]) PTF_RUN_TEST(PrintPacketAndLayersTest, "packet;print"); PTF_RUN_TEST(ProtocolFamilyMembershipTest, "packet"); PTF_RUN_TEST(PacketParseLayerLimitTest, "packet"); + PTF_RUN_TEST(PacketParseMultiLayerTest, "packet"); PTF_RUN_TEST(HttpRequestParseMethodTest, "http"); PTF_RUN_TEST(HttpRequestLayerParsingTest, "http"); diff --git a/Tests/PcppTestUtilities/Resources.cpp b/Tests/PcppTestUtilities/Resources.cpp index 727afb3231..ef991223ce 100644 --- a/Tests/PcppTestUtilities/Resources.cpp +++ b/Tests/PcppTestUtilities/Resources.cpp @@ -51,7 +51,8 @@ namespace pcpp_tests } } // namespace - ResourceProvider::ResourceProvider(std::string dataRoot) : m_DataRoot(std::move(dataRoot)) + ResourceProvider::ResourceProvider(std::string dataRoot, bool frozen) + : m_DataRoot(std::move(dataRoot)), m_Frozen(frozen) {} Resource ResourceProvider::loadResource(const char* filename, ResourceType resourceType) const @@ -101,6 +102,56 @@ namespace pcpp_tests throw std::invalid_argument("Unsupported resource type"); } } + + void ResourceProvider::saveResource(ResourceType resourceType, const char* filename, const uint8_t* data, + size_t length) const + { + if (m_Frozen) + { + throw std::runtime_error("Resource provider is frozen and does not allow saving"); + } + + if (data == nullptr || length == 0) + { + throw std::invalid_argument("Data is null or length is zero"); + } + + std::string fullPath; + if (!m_DataRoot.empty()) + { + fullPath = m_DataRoot + getOsPathSeparator() + filename; + } + else + { + fullPath = filename; + } + + auto const requireOpen = [filename](std::ofstream const& fileStream) { + if (!fileStream) + { + throw std::runtime_error(std::string("Failed to open file: ") + filename); + } + }; + + switch (resourceType) + { + case ResourceType::HexData: + { + std::ofstream fileStream(fullPath); + requireOpen(fileStream); + for (size_t i = 0; i < length; ++i) + { + fileStream << std::hex; + fileStream.width(2); + fileStream.fill('0'); + fileStream << static_cast(data[i]); + } + break; + } + default: + throw std::invalid_argument("Unsupported resource type"); + } + } } // namespace utils namespace diff --git a/Tests/PcppTestUtilities/Resources.h b/Tests/PcppTestUtilities/Resources.h index ec4f2d8aa9..f18fce0274 100644 --- a/Tests/PcppTestUtilities/Resources.h +++ b/Tests/PcppTestUtilities/Resources.h @@ -29,7 +29,8 @@ namespace pcpp_tests public: /// @brief Constructs a ResourceProvider with a specified data root directory. /// @param dataRoot The root directory from which resources will be loaded. - explicit ResourceProvider(std::string dataRoot); + /// @param frozen If true, the provider is read-only and does not allow saving resources. + explicit ResourceProvider(std::string dataRoot, bool frozen = true); /// @brief Loads a resource from resource provider. /// @param filename The name of the resource file to load. @@ -43,8 +44,18 @@ namespace pcpp_tests /// @return A vector containing the loaded data. std::vector loadResourceToVector(const char* filename, ResourceType resourceType) const; + /// @brief Saves a resource to the resource provider. + /// @param resourceType The type of the resource being saved. + /// @param filename The name of the file to save the resource to. + /// @param data Pointer to the data to be saved. + /// @param length The length of the data in bytes. + /// @throw std::runtime_error if the provider is frozen and does not allow saving. + void saveResource(ResourceType resourceType, const char* filename, const uint8_t* data, + size_t length) const; + private: std::string m_DataRoot; ///< The root directory for test data files + bool m_Frozen = true; ///< Indicates if the provider is frozen (no modifications allowed) }; } // namespace utils