|
23 | 23 | #include "sst/core/serialization/impl/serialize_utility.h" |
24 | 24 | #include "sst/core/warnmacros.h" |
25 | 25 |
|
| 26 | +#include <array> |
26 | 27 | #include <deque> |
27 | 28 | #include <forward_list> |
28 | 29 | #include <list> |
29 | 30 | #include <map> |
30 | 31 | #include <set> |
31 | 32 | #include <string> |
32 | 33 | #include <tuple> |
| 34 | +#include <type_traits> |
33 | 35 | #include <unordered_map> |
34 | 36 | #include <unordered_set> |
35 | 37 | #include <utility> |
| 38 | +#include <variant> |
36 | 39 | #include <vector> |
37 | 40 |
|
38 | 41 | namespace SST::CoreTestSerialization { |
@@ -223,6 +226,39 @@ checkContainerSerializeDeserialize(T*& data) |
223 | 226 | return true; |
224 | 227 | }; |
225 | 228 |
|
| 229 | +// std::variant |
| 230 | +auto checkVariant = [](auto& data, auto& result) { |
| 231 | + using T = std::decay_t<decltype(data)>; |
| 232 | + using R = std::decay_t<decltype(result)>; |
| 233 | + |
| 234 | + if constexpr ( !std::is_same_v<T, R> ) { |
| 235 | + // ignore the cases where T != R at compile-time, since they are excluded by index() runtime equality test |
| 236 | + return false; |
| 237 | + } |
| 238 | + else if constexpr ( std::is_same_v<T, std::string> || std::is_arithmetic_v<T> || std::is_enum_v<T> ) { |
| 239 | + return data == result; |
| 240 | + } |
| 241 | + else if constexpr ( std::is_same_v<T, std::vector<int>> ) { |
| 242 | + if ( data.size() != result.size() ) return false; |
| 243 | + for ( size_t i = 0; i < data.size(); ++i ) |
| 244 | + if ( data[i] != result[i] ) return false; |
| 245 | + return true; |
| 246 | + } |
| 247 | + else { |
| 248 | + static_assert(sizeof(T) == 0, "Unsupported type in checkVariant()"); |
| 249 | + } |
| 250 | +}; |
| 251 | + |
| 252 | +template <typename... Types> |
| 253 | +bool |
| 254 | +checkVariantSerializeDeserialize(std::variant<Types...>& data) |
| 255 | +{ |
| 256 | + std::variant<Types...> result; |
| 257 | + serializeDeserialize(data, result); |
| 258 | + if ( result.index() != data.index() ) return false; |
| 259 | + return std::visit(checkVariant, data, result); |
| 260 | +}; |
| 261 | + |
226 | 262 | // Arrays |
227 | 263 |
|
228 | 264 | template <typename> |
@@ -846,6 +882,40 @@ coreTestSerialization::coreTestSerialization(ComponentId_t id, Params& params) : |
846 | 882 | if ( !passed ) out.output("ERROR: unordered_multiset<int32_t>* did not serialize/deserialize properly\n"); |
847 | 883 | delete umultiset_in; |
848 | 884 | } |
| 885 | + else if ( test == "variant" ) { |
| 886 | + std::variant<std::vector<int>, double, std::string> var; |
| 887 | + for ( int ntry = 0; ntry < 5; ++ntry ) { |
| 888 | + bool passed = false; |
| 889 | + |
| 890 | + // Generate random variant each try |
| 891 | + switch ( rng->generateNextUInt32() % std::variant_size_v<decltype(var)> ) { |
| 892 | + case 0: |
| 893 | + { |
| 894 | + var = std::vector<int>(rng->generateNextUInt32() % 1000); |
| 895 | + for ( auto& e : std::get<0>(var) ) |
| 896 | + e = rng->generateNextInt32(); |
| 897 | + passed = checkVariantSerializeDeserialize(var); |
| 898 | + break; |
| 899 | + } |
| 900 | + case 1: |
| 901 | + { |
| 902 | + var = double(rng->generateNextInt32()); |
| 903 | + passed = checkVariantSerializeDeserialize(var); |
| 904 | + break; |
| 905 | + } |
| 906 | + case 2: |
| 907 | + { |
| 908 | + std::string str; |
| 909 | + size_t len = rng->generateNextUInt32() % 100; |
| 910 | + for ( size_t i = 0; i < len; ++i ) |
| 911 | + str += "0123456789"[rng->generateNextUInt32() % 10]; |
| 912 | + var = str; |
| 913 | + passed = checkVariantSerializeDeserialize(var); |
| 914 | + } |
| 915 | + } |
| 916 | + if ( !passed ) out.output("ERROR: std::variant<...> did not serialize/deserialize properly\n"); |
| 917 | + } |
| 918 | + } |
849 | 919 | else if ( test == "map_to_vector" ) { |
850 | 920 |
|
851 | 921 | // Containers to other containers |
|
0 commit comments