Skip to content

Commit f3b650b

Browse files
committed
[tmva][sofie] not operator changes
1 parent 2d69bc3 commit f3b650b

File tree

6 files changed

+60
-22
lines changed

6 files changed

+60
-22
lines changed

tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ enum class EBasicUnaryOperator {
1313
kReciprocal,
1414
kSqrt,
1515
kNeg,
16+
kNot,
1617
kExp,
1718
kLog,
1819
kSin,
@@ -43,6 +44,12 @@ struct UnaryOpTraits<T, EBasicUnaryOperator::kNeg> {
4344
static std::string Op(const std::string &X) { return "-" + X; }
4445
};
4546

47+
template <typename T>
48+
struct UnaryOpTraits<T, EBasicUnaryOperator::kNot> {
49+
static std::string Name() { return "Not"; }
50+
static std::string Op(const std::string &X) { return "!" + X; }
51+
};
52+
4653
template <typename T>
4754
struct UnaryOpTraits<T, EBasicUnaryOperator::kExp> {
4855
static std::string Name() { return "Exp"; }
@@ -122,7 +129,14 @@ public:
122129
out << SP << "\n//---- Operator" << UnaryOpTraits<T, Op>::Name() << " " << OpName << "\n";
123130
size_t length = ConvertShapeToLength(fShapeX);
124131
out << SP << "for (size_t i = 0; i < " << length << "; i++) {\n";
125-
out << SP << SP << "tensor_" << fNY << "[i] = " << UnaryOpTraits<T, Op>::Op("tensor_" + fNX + "[i]") << ";\n";
132+
133+
// since NOT is operated on a boolean vector, for which we do not use a pointer
134+
if (Op == EBasicUnaryOperator::kNot){
135+
out << SP << SP << "fTensor_" << fNY << "[i] = " << UnaryOpTraits<T, Op>::Op("fTensor_" + fNX + "[i]") << ";\n";
136+
} else {
137+
out << SP << SP << "tensor_" << fNY << "[i] = " << UnaryOpTraits<T, Op>::Op("tensor_" + fNX + "[i]") << ";\n";
138+
}
139+
126140
out << SP << "}\n";
127141
return out.str();
128142
}

tmva/sofie/test/TestCustomModelsFromONNX.cxx

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -529,22 +529,20 @@ TEST(ONNX, Neg)
529529

530530
TEST(ONNX, Not)
531531
{
532-
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
533-
534532
// Preparing the standard input
535-
std::vector<float> input({-0.7077, 1.0645, -0.8607, 0.2085, 4.5335, -3.4592});
533+
std::vector<bool> input({false, true, false, false, true, true});
536534

537535
TMVA_SOFIE_Not::Session s("Not_FromONNX.dat");
538-
std::vector<float> output = s.infer(input.data());
536+
std::vector<bool> output = s.infer(input.data());
539537

540538
// Checking output size
541-
EXPECT_EQ(output.size(), sizeof(Not_ExpectedOutput::outputs) / sizeof(float));
539+
EXPECT_EQ(output.size(), sizeof(Not_ExpectedOutput::outputs) / sizeof(bool));
542540

543-
float *correct = Not_ExpectedOutput::outputs;
541+
bool *correct = Not_ExpectedOutput::outputs;
544542

545543
// Checking every output value, one by one
546544
for (size_t i = 0; i < output.size(); ++i) {
547-
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
545+
EXPECT_EQ(output[i], correct[i]);
548546
}
549547
}
550548

tmva/sofie/test/input_models/Not.onnx

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11

2-
 onnx-example:M
3-

4-
input1output"NotNotGraphZ
5-
input1
6-
7-

8-
b
9-
output
10-
11-

12-
B
2+
onnx-equal-not-example:�
3+
%
4+
input1
5+
input2 equal_output"Equal
6+

7+
equal_outputoutput"NotEqualNotGraphZ
8+
input1
9+

10+

11+
Z
12+
input2
13+

14+

15+
b
16+
output
17+
 
18+

19+
B
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
namespace Not_ExpectedOutput {
2-
float outputs[] = {0.7077, -1.0645, 0.8607, -0.2085, -4.5335, 3.4592};
2+
bool outputs[] = {true, false, true, true, false, false};
33
} // namespace Not_ExpectedOutput

tmva/sofie_parsers/src/ParseBasicUnary.cxx

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,26 @@ std::unique_ptr<ROperator> ParseBasicUnary(RModelParser_ONNX &parser, const onnx
2020
" but its type is not yet registered");
2121
}
2222

23+
if constexpr (Op == EBasicUnaryOperator::kNot) {
24+
if (input_type != ETensorType::BOOL) {
25+
throw std::runtime_error("TMVA::SOFIE - Unary NOT operator expects input type BOOL, got type " +
26+
std::to_string(static_cast<int>(input_type)));
27+
}
28+
}
29+
2330
std::unique_ptr<ROperator> op;
2431
std::string output_name = nodeproto.output(0);
25-
32+
2633
switch (input_type) {
2734
case ETensorType::FLOAT:
2835
op.reset(new ROperator_BasicUnary<float, Op>(input_name, output_name));
2936
break;
37+
case ETensorType::BOOL:
38+
if constexpr (Op == EBasicUnaryOperator::kNot) {
39+
op.reset(new ROperator_BasicUnary<bool, Op>(input_name, output_name));
40+
break;
41+
}
42+
[[fallthrough]];
3043
default:
3144
throw std::runtime_error("TMVA::SOFIE - Unsupported - Binary Operator does not yet support input type " +
3245
std::to_string(static_cast<int>(input_type)));
@@ -55,6 +68,11 @@ ParserFuncSignature ParseNeg = [](RModelParser_ONNX &parser, const onnx::NodePro
5568
return ParseBasicUnary<EBasicUnaryOperator::kNeg>(parser, nodeproto);
5669
};
5770

71+
// Parse Not
72+
ParserFuncSignature ParseNot = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
73+
return ParseBasicUnary<EBasicUnaryOperator::kNot>(parser, nodeproto);
74+
};
75+
5876
// Parse Exp
5977
ParserFuncSignature ParseExp = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
6078
return ParseBasicUnary<EBasicUnaryOperator::kExp>(parser, nodeproto);

tmva/sofie_parsers/src/RModelParser_ONNX.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace SOFIE {
2020
extern ParserFuncSignature ParseSqrt;
2121
extern ParserFuncSignature ParseReciprocal;
2222
extern ParserFuncSignature ParseNeg;
23+
extern ParserFuncSignature ParseNot;
2324
extern ParserFuncSignature ParseExp;
2425
extern ParserFuncSignature ParseLog;
2526
extern ParserFuncSignature ParseSin;
@@ -161,7 +162,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
161162
RegisterOperator("Sqrt", ParseSqrt);
162163
RegisterOperator("Reciprocal", ParseReciprocal);
163164
RegisterOperator("Neg", ParseNeg);
164-
RegisterOperator("Not", ParseNeg);
165+
RegisterOperator("Not", ParseNot);
165166
RegisterOperator("Exp", ParseExp);
166167
RegisterOperator("Log", ParseLog);
167168
RegisterOperator("Sin", ParseSin);

0 commit comments

Comments
 (0)