Skip to content

Commit 8066c4d

Browse files
committed
Add support for trigger comparison between two variables
1 parent 4e4d4e5 commit 8066c4d

File tree

2 files changed

+329
-37
lines changed

2 files changed

+329
-37
lines changed

src/sst/core/serialization/objectMap.h

Lines changed: 278 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <iostream>
2020
#include <map>
2121
#include <string>
22+
#include <type_traits>
2223
#include <typeinfo>
2324
#include <utility>
2425
#include <vector>
@@ -201,7 +202,7 @@ class ObjectMap
201202
Function that will get called when this object is deactivated
202203
(i.e selectParent() is called)
203204
*/
204-
virtual void deactivate_callback() {}
205+
virtual void deactivate_callback() {}
205206

206207
private:
207208
/**
@@ -329,6 +330,13 @@ class ObjectMap
329330
return nullptr;
330331
}
331332

333+
virtual ObjectMapComparison* getComparisonVar(const std::string& UNUSED(name), ObjectMapComparison::Op UNUSED(op),
334+
const std::string& UNUSED(name2), ObjectMap* UNUSED(var2))
335+
{
336+
printf("In virtual ObjectMapComparison\n");
337+
return nullptr;
338+
}
339+
332340
virtual ObjectBuffer* getObjectBuffer(const std::string& UNUSED(name), size_t UNUSED(sz)) { return nullptr; }
333341

334342
/************ Functions for walking the Object Hierarchy ************/
@@ -719,7 +727,7 @@ class ObjectMapClass : public ObjectMapWithChildren
719727

720728

721729
/**
722-
Templated implementation of ObjectMapComparison
730+
Template implementation of ObjectMapComparison for <var> <op> <value>
723731
*/
724732
template <typename T>
725733
class ObjectMapComparison_impl : public ObjectMapComparison
@@ -797,6 +805,175 @@ class ObjectMapComparison_impl : public ObjectMapComparison
797805
Op op_ = Op::INVALID;
798806
}; // class ObjectMapComparison_impl
799807

808+
/**
809+
Templated compareType implementations
810+
*/
811+
#if 1
812+
template <typename V1>
813+
bool
814+
cmp(V1 v, ObjectMapComparison::Op op, V1 w)
815+
{
816+
switch ( op ) {
817+
case ObjectMapComparison::Op::LT:
818+
return v < w;
819+
break;
820+
case ObjectMapComparison::Op::LTE:
821+
return v <= w;
822+
break;
823+
case ObjectMapComparison::Op::GT:
824+
return v > w;
825+
break;
826+
case ObjectMapComparison::Op::GTE:
827+
return v >= w;
828+
break;
829+
case ObjectMapComparison::Op::EQ:
830+
return v == w;
831+
break;
832+
case ObjectMapComparison::Op::NEQ:
833+
return v != w;
834+
break;
835+
default:
836+
std::cout << "Invalid comparison operator\n";
837+
return false;
838+
break;
839+
}
840+
}
841+
#endif
842+
843+
// Comparison of two variables of the same type
844+
template <typename U1, typename U2, std::enable_if_t<std::is_same_v<U1, U2>, int> = true>
845+
bool
846+
compareType(U1 v, ObjectMapComparison::Op op, U2 w)
847+
{
848+
// Handle same type - just compare
849+
printf(" CMP: Same type\n");
850+
return cmp(v, op, w);
851+
}
852+
853+
// Comparison of two variables with different arithmetic types
854+
template <typename U1, typename U2,
855+
std::enable_if_t<!std::is_same_v<U1, U2> && std::is_arithmetic_v<U1> && std::is_arithmetic_v<U2>, int> = true>
856+
bool
857+
compareType(U1 v, ObjectMapComparison::Op op, U2 w)
858+
{
859+
// printf(" CMP: Different types\n");
860+
// Handle integrals (bool, char, flavors of int)
861+
if ( std::is_integral_v<U1> && std::is_integral_v<U2> ) {
862+
// both unsigned integrals - cast to unsigned long long
863+
if ( std::is_unsigned_v<U1> && std::is_unsigned_v<U2> ) {
864+
printf(" CMP: Both unsigned integrals\n");
865+
unsigned long long v1 = static_cast<unsigned long long>(v);
866+
unsigned long long w1 = static_cast<unsigned long long>(w);
867+
return cmp(v1, op, w1);
868+
}
869+
// both integers but at least one signed - cast to signed long long
870+
else {
871+
printf(" CMP: Not both unsigned integrals\n");
872+
long long v1 = static_cast<long long>(v);
873+
long long w1 = static_cast<long long>(w);
874+
return cmp(v1, op, w1);
875+
}
876+
}
877+
// Handle float/double combinations - cast to long double
878+
else if ( std::is_floating_point_v<U1> && std::is_floating_point_v<U2> ) {
879+
printf(" CMP: Both fp\n");
880+
long double v1 = static_cast<long double>(v);
881+
long double w1 = static_cast<long double>(w);
882+
return cmp(v1, op, w1);
883+
}
884+
else { // Integral and FP comparison - cast integral to fp
885+
printf(" CMP: integral and fp\n");
886+
if ( std::is_integral_v<U1> ) {
887+
if ( std::is_same_v<U2, float> ) {
888+
float v1 = static_cast<float>(v);
889+
float w1 = static_cast<float>(w); // unnecessary but compiler needs to know they are the same
890+
return cmp(v1, op, w1);
891+
}
892+
else if ( std::is_same_v<U2, double> ) {
893+
double v1 = static_cast<double>(v);
894+
double w1 = static_cast<double>(w); // unnecessary ...
895+
return cmp(v1, op, w1);
896+
}
897+
else {
898+
long double v1 = static_cast<long double>(v);
899+
long double w1 = static_cast<long double>(w); // unnecessary ...
900+
return cmp(v1, op, w1);
901+
}
902+
}
903+
else {
904+
if ( std::is_same_v<U1, float> ) {
905+
float v1 = static_cast<float>(v); // unnecessary ...
906+
float w1 = static_cast<float>(w);
907+
return cmp(v1, op, w1);
908+
}
909+
else if ( std::is_same_v<U1, double> ) {
910+
double v1 = static_cast<double>(v); // unnecessary ...
911+
double w1 = static_cast<double>(w);
912+
return cmp(v1, op, w1);
913+
}
914+
else {
915+
long double v1 = static_cast<long double>(v); // unnecessary ...
916+
long double w1 = static_cast<long double>(w);
917+
return cmp(v1, op, w1);
918+
}
919+
}
920+
}
921+
}
922+
923+
// Comparison of two variables with at least one non-arithmetic type
924+
template <typename U1, typename U2,
925+
std::enable_if_t<!std::is_same_v<U1, U2> && !std::is_arithmetic_v<U1> || !std::is_arithmetic_v<U2>, int> = true>
926+
bool
927+
compareType(U1 UNUSED(v), ObjectMapComparison::Op UNUSED(op), U2 UNUSED(w))
928+
{
929+
// We shouldn't get here.... Can I throw an error somehow?
930+
printf(" ERROR: CMP: Does not support non-arithmetic types\n");
931+
return false;
932+
}
933+
934+
935+
/**
936+
Template implementation of ObjectMapComparison for <var> <op> <var>
937+
*/
938+
template <typename T1, typename T2>
939+
class ObjectMapComparison_var : public ObjectMapComparison
940+
{
941+
public:
942+
ObjectMapComparison_var(const std::string& name1, T1* var1, Op op, const std::string& name2, T2* var2) :
943+
ObjectMapComparison(name1),
944+
name2_(name2),
945+
var1_(var1),
946+
op_(op),
947+
var2_(var2)
948+
{}
949+
950+
bool compare() override
951+
{
952+
T1 v1 = *var1_;
953+
T2 v2 = *var2_;
954+
return compareType(v1, op_, v2);
955+
}
956+
957+
std::string getCurrentValue() override { return SST::Core::to_string(*var1_) + " " + SST::Core::to_string(*var2_); }
958+
959+
void* getVar() override { return var1_; }
960+
961+
void print() override
962+
{
963+
std::cout << name_ << " " << getStringFromOp(op_);
964+
if ( op_ == Op::CHANGED )
965+
std::cout << " ";
966+
else
967+
std::cout << " " << name2_ << " ";
968+
}
969+
970+
private:
971+
std::string name2_ = "";
972+
T1* var1_ = nullptr;
973+
Op op_ = Op::INVALID;
974+
T2* var2_ = nullptr;
975+
}; // class ObjectMapComparison_impl
976+
800977

801978
class ObjectBuffer
802979
{
@@ -1085,17 +1262,18 @@ class ObjectMapFundamental : public ObjectMap
10851262
*/
10861263
virtual void set_impl(const std::string& value) override { *addr_ = SST::Core::from_string<T>(value); }
10871264

1088-
virtual bool checkValue(const std::string& value) override {
1265+
virtual bool checkValue(const std::string& value) override
1266+
{
10891267
bool ret = false;
10901268
try {
10911269
T v = SST::Core::from_string<T>(value);
10921270
ret = static_cast<bool>(v);
10931271
}
1094-
catch (const std::invalid_argument& e) {
1272+
catch ( const std::invalid_argument& e ) {
10951273
std::cerr << "Error: Invalid value: " << value << std::endl;
10961274
return false;
10971275
}
1098-
catch (const std::out_of_range& e) {
1276+
catch ( const std::out_of_range& e ) {
10991277
std::cerr << "Error: Value is out of range: " << value << std::endl;
11001278
return false;
11011279
}
@@ -1150,12 +1328,105 @@ class ObjectMapFundamental : public ObjectMap
11501328
std::string getType() override { return demangle_name(typeid(T).name()); }
11511329

11521330
ObjectMapComparison* getComparison(
1153-
1154-
const std::string& name, ObjectMapComparison::Op op, const std::string& value) override
1331+
const std::string& name, ObjectMapComparison::Op UNUSED(op), const std::string& value) override
11551332
{
11561333
return new ObjectMapComparison_impl<T>(name, addr_, op, value);
11571334
}
11581335

1336+
ObjectMapComparison* getComparisonVar(
1337+
const std::string& name, ObjectMapComparison::Op op, const std::string& name2, ObjectMap* var2) override
1338+
{
1339+
// Ensure var2 is fundamental type
1340+
if ( !var2->isFundamental() ) {
1341+
printf("Triggers can only use fundamental types; %s is not "
1342+
"fundamental\n",
1343+
name2.c_str());
1344+
return nullptr;
1345+
}
1346+
1347+
#if 1
1348+
std::cout << "In ObjectMapComparison_var: " << name << " " << name2 << std::endl;
1349+
// std::cout << "typeid(T): " << demangle_name(typeid(T).name()) << std::endl;
1350+
std::string type1 = getType();
1351+
std::cout << "getType(v1): " << type1 << std::endl;
1352+
std::string type = var2->getType();
1353+
std::cout << "getType(v2): " << type << std::endl;
1354+
#endif
1355+
1356+
// Only support arithmetic types for now
1357+
if ( std::is_arithmetic_v<T> ) {
1358+
if ( type == "int" ) {
1359+
int* addr2 = static_cast<int*>(var2->getAddr());
1360+
return new ObjectMapComparison_var<T, int>(name, addr_, op, name2, addr2);
1361+
}
1362+
else if ( type == "unsigned int" ) {
1363+
unsigned int* addr2 = static_cast<unsigned int*>(var2->getAddr());
1364+
return new ObjectMapComparison_var<T, unsigned int>(name, addr_, op, name2, addr2);
1365+
}
1366+
else if ( type == "long" ) {
1367+
long* addr2 = static_cast<long*>(var2->getAddr());
1368+
return new ObjectMapComparison_var<T, long>(name, addr_, op, name2, addr2);
1369+
}
1370+
else if ( type == "unsigned long" ) {
1371+
unsigned long* addr2 = static_cast<unsigned long*>(var2->getAddr());
1372+
return new ObjectMapComparison_var<T, unsigned long>(name, addr_, op, name2, addr2);
1373+
}
1374+
else if ( type == "char" ) {
1375+
char* addr2 = static_cast<char*>(var2->getAddr());
1376+
return new ObjectMapComparison_var<T, char>(name, addr_, op, name2, addr2);
1377+
}
1378+
else if ( type == "signed char" ) {
1379+
signed char* addr2 = static_cast<signed char*>(var2->getAddr());
1380+
return new ObjectMapComparison_var<T, signed char>(name, addr_, op, name2, addr2);
1381+
}
1382+
else if ( type == "unsigned char" ) {
1383+
unsigned char* addr2 = static_cast<unsigned char*>(var2->getAddr());
1384+
return new ObjectMapComparison_var<T, unsigned char>(name, addr_, op, name2, addr2);
1385+
}
1386+
else if ( type == "short" ) {
1387+
short* addr2 = static_cast<short*>(var2->getAddr());
1388+
return new ObjectMapComparison_var<T, short>(name, addr_, op, name2, addr2);
1389+
}
1390+
else if ( type == "unsigned short" ) {
1391+
unsigned short* addr2 = static_cast<unsigned short*>(var2->getAddr());
1392+
return new ObjectMapComparison_var<T, unsigned short>(name, addr_, op, name2, addr2);
1393+
}
1394+
else if ( type == "long long" ) {
1395+
long long* addr2 = static_cast<long long*>(var2->getAddr());
1396+
return new ObjectMapComparison_var<T, long long>(name, addr_, op, name2, addr2);
1397+
}
1398+
else if ( type == "unsigned long long" ) {
1399+
unsigned long long* addr2 = static_cast<unsigned long long*>(var2->getAddr());
1400+
return new ObjectMapComparison_var<T, unsigned long long>(name, addr_, op, name2, addr2);
1401+
}
1402+
else if ( type == "bool" ) {
1403+
bool* addr2 = static_cast<bool*>(var2->getAddr());
1404+
return new ObjectMapComparison_var<T, bool>(name, addr_, op, name2, addr2);
1405+
}
1406+
else if ( type == "float" ) {
1407+
float* addr2 = static_cast<float*>(var2->getAddr());
1408+
return new ObjectMapComparison_var<T, float>(name, addr_, op, name2, addr2);
1409+
}
1410+
else if ( type == "double" ) {
1411+
double* addr2 = static_cast<double*>(var2->getAddr());
1412+
return new ObjectMapComparison_var<T, double>(name, addr_, op, name2, addr2);
1413+
}
1414+
else if ( type == "long double" ) {
1415+
long double* addr2 = static_cast<long double*>(var2->getAddr());
1416+
return new ObjectMapComparison_var<T, long double>(name, addr_, op, name2, addr2);
1417+
}
1418+
1419+
else {
1420+
std::cout << "Invalid type for comparison: " << name2 << "(" << type << ")\n";
1421+
return nullptr;
1422+
}
1423+
} // end if first var is arithmetic
1424+
else {
1425+
std::cout << "Invalid type for comparison: " << name2 << "(" << type << ")\n";
1426+
return nullptr;
1427+
}
1428+
}
1429+
11591430
ObjectBuffer* getObjectBuffer(const std::string& name, size_t sz) override
11601431
{
11611432
return new ObjectBuffer_impl<T>(name, addr_, sz);

0 commit comments

Comments
 (0)