Skip to content

Commit 8817a96

Browse files
committed
[tmva][sofie] Add support for NonZero operator and fix handling of booleans
Correct handle booleans types when parsing from ONNX
1 parent 0e1a5e4 commit 8817a96

File tree

13 files changed

+321
-17
lines changed

13 files changed

+321
-17
lines changed

tmva/sofie/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
6666
TMVA/ROperator_ScatterElements.hxx
6767
TMVA/ROperator_Gather.hxx
6868
TMVA/ROperator_GatherND.hxx
69+
TMVA/ROperator_NonZero.hxx
6970
TMVA/SOFIE_common.hxx
7071
TMVA/SOFIEHelpers.hxx
7172

tmva/sofie/inc/TMVA/ROperator_Constant.hxx

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@ public:
122122
model.AddConstantTensor(fNY, fShape, fValues);
123123
if (model.Verbose()) {
124124
std::cout << "adding constant tensor " << fNY << " with shape " << ConvertShapeToString(fShape)
125-
<< " and values [";
126-
for (auto v : fValues) std::cout << " " << v;
127-
std::cout << "]" << std::endl;
125+
<< " and values " << ConvertValuesToString(fValues) << std::endl;
128126
}
129127
} else {
130128
model.AddIntermediateTensor(fNY, ConvertStringToType(TensorType<T>::Name()), fDimOutputShape);

tmva/sofie/inc/TMVA/ROperator_GatherND.hxx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,7 @@ public:
207207
std::string idIndex;
208208
for (size_t j = 0; j < fBatchDims; j++) {
209209
std::string index = "i_" + std::to_string(j);
210-
for (size_t k = 0; k <= j; k++)
211-
out << SP;
210+
for (size_t k = 0; k <= j; k++) out << SP;
212211
out << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[j] << "; " << index << "++) {\n";
213212
if (j > 0) {
214213
outIndex += " + ";
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
#ifndef TMVA_SOFIE_ROPERATOR_NONZERO
2+
#define TMVA_SOFIE_ROPERATOR_NONZERO
3+
4+
#include "TMVA/SOFIE_common.hxx"
5+
#include "TMVA/ROperator.hxx"
6+
#include "TMVA/RModel.hxx"
7+
8+
#include <sstream>
9+
10+
namespace TMVA{
11+
namespace Experimental{
12+
namespace SOFIE{
13+
14+
template<class T>
15+
class ROperator_NonZero final : public ROperator
16+
{
17+
18+
private:
19+
20+
std::string fNX;
21+
std::string fNY;
22+
std::vector<Dim> fShapeX;
23+
std::vector<Dim> fShapeY;
24+
25+
public:
26+
ROperator_NonZero(){}
27+
ROperator_NonZero(std::string nameX, std::string nameY):
28+
fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){
29+
fInputTensorNames = { fNX };
30+
fOutputTensorNames = { fNY };
31+
}
32+
33+
34+
35+
void Initialize(RModel& model) override {
36+
if (model.CheckIfTensorAlreadyExist(fNX) == false){ //input must be a graph input, or already initialized intermediate tensor
37+
throw std::runtime_error("TMVA SOFIE NonZero Op Input Tensor " + fNX + " is not found in model");
38+
}
39+
40+
41+
// case input is constant
42+
if (model.IsConstantTensor(fNX)) {
43+
// compute output directly
44+
T * data = static_cast<T*>(model.GetInitializedTensorData(fNX).get());
45+
// shape is fully known
46+
auto shapeX = model.GetTensorShape(fNX);
47+
std::vector<size_t> shapeY(2);
48+
shapeY[0] = shapeX.size();
49+
auto length = ConvertShapeToLength(shapeX);
50+
auto strides = UTILITY::ComputeStrideFromShape(shapeX);
51+
std::vector<std::vector<int64_t>> nonzero_indices;
52+
for (size_t i = 0; i < length; i++) {
53+
if (data[i] != 0) {
54+
// get indices
55+
size_t flat_index = i;
56+
std::vector<int64_t> indices(shapeX.size());
57+
for (size_t j = 0; j < shapeX.size(); ++j) {
58+
indices[j] = flat_index / strides[j];
59+
flat_index %= strides[j];
60+
}
61+
nonzero_indices.emplace_back(indices);
62+
}
63+
}
64+
shapeY[1] = nonzero_indices.size();
65+
std::vector<int64_t> dataY(shapeY[0]* shapeY[1]);
66+
size_t k = 0;
67+
for (size_t i = 0; i < shapeY[0]; i++) {
68+
for (size_t j = 0; j < shapeY[1]; j++) {
69+
dataY[k] = nonzero_indices[j][i];
70+
k++;
71+
}
72+
}
73+
if (dataY.empty()) {
74+
// no zero elements found
75+
dataY.resize(1);
76+
shapeY.clear(); // use an empty shape
77+
}
78+
79+
model.AddConstantTensor(fNY, shapeY, dataY);
80+
if (model.Verbose()) {
81+
std::cout << "NonZero : " << fNX << " -> " << fNY << " " << ConvertShapeToString(shapeY)
82+
<< " : " << ConvertValuesToString(dataY) << std::endl;
83+
}
84+
fIsOutputConstant = true;
85+
86+
} else {
87+
88+
fShapeX = model.GetDimTensorShape(fNX);
89+
90+
// output shape(-1) depends on number of elements of non zero values
91+
// first dim is rank of input
92+
fShapeY.resize(2);
93+
fShapeY[0] = fShapeX.size();
94+
95+
// identify as -1 since we will declare maximum as size of input
96+
fShapeY[1] = Dim{std::string("v_NonZero_") + fNX, static_cast<size_t>(-1)};
97+
98+
model.AddIntermediateTensor(fNY, ETensorType::INT64, fShapeY);
99+
if (model.Verbose()) {
100+
std::cout << "NonZero : " << fNX << " -> " << fNY << " " << ConvertShapeToString(fShapeY) << std::endl;
101+
}
102+
}
103+
}
104+
std::string GenerateSessionMembersCode(std::string /*opName*/) override {
105+
if (fIsOutputConstant) return "";
106+
// define output value used as max non zero with max size = input shape * N
107+
auto inputLength = ConvertDimShapeToLength(fShapeX);
108+
std::stringstream out;
109+
out << SP << "size_t v_NonZero_" << fNX << " = " << inputLength << ";\n";
110+
return out.str();
111+
}
112+
113+
114+
std::string Generate(std::string opName) override {
115+
if (fIsOutputConstant) {
116+
return "";
117+
}
118+
opName = "op_" + opName;
119+
if (fShapeX.empty()) {
120+
throw std::runtime_error("TMVA SOFIE Operator NonZero called to Generate without being initialized first");
121+
}
122+
std::stringstream out;
123+
auto inputLength = ConvertDimShapeToLength(fShapeX);
124+
auto maxStrideY = inputLength;
125+
size_t dims = fShapeX.size();
126+
out << "\n//------ NonZero\n";
127+
128+
std::string vnonzero = "v_NonZero_" + fNX;
129+
130+
// loop on input indices
131+
out << "size_t offset_" << opName << " = 0;\n";
132+
out << vnonzero << " = 0;\n";
133+
for (size_t j = 0; j < dims; j++) {
134+
std::string index = "i_" + std::to_string(j);
135+
for (size_t k = 0; k <= j; k++) out << SP;
136+
out << "for (size_t " << index << " = 0; " << index << " < " << fShapeX[j] << "; " << index << "++) {\n";
137+
}
138+
for (size_t k = 0; k <= dims; k++) out << SP;
139+
out << "if (tensor_" << fNX << "[offset_" << opName << "]) {\n";
140+
for (size_t k = 0; k <= dims+1; k++) out << SP;
141+
out << vnonzero << "++;\n";
142+
for (size_t j = 0; j < dims; j++) {
143+
for (size_t k = 0; k <= dims+1; k++) out << SP;
144+
out << "tensor_" << fNY << "[" << maxStrideY << " * " << j << " + " << vnonzero << "] = i_" << j << ";\n";
145+
}
146+
for (size_t k = 0; k <= dims; k++) out << SP;
147+
out << "}\n";
148+
//end loops
149+
for (size_t j = dims; j > 0; j--) {
150+
for (size_t k = 0; k <j; k++) out << SP;
151+
out << "}\n";
152+
}
153+
// now we need to rearrange the vector if nonzero is less than length of input
154+
out << SP << "if (" << vnonzero << " < " << inputLength << "){\n";
155+
for (size_t j = 1; j < dims; j++) {
156+
out << SP << SP << "std::copy(tensor_" << fNY;
157+
if (j>0) out << " + " << maxStrideY;
158+
if (j>1) out << " * " << j;
159+
out << ", tensor_" << fNY;
160+
if (j>0) out << " + " << maxStrideY;
161+
if (j>1) out << " * " << j;
162+
out << " + " << vnonzero << ", tensor_" << fNY;
163+
if (j>0) out << " + " << vnonzero;
164+
if (j>1) out << "* " << j;
165+
out << ");\n";
166+
}
167+
out << SP << "}\n";
168+
169+
return out.str();
170+
}
171+
172+
};
173+
174+
}//SOFIE
175+
}//Experimental
176+
}//TMVA
177+
178+
179+
#endif //TMVA_SOFIE_ROPERATOR_NonZero

tmva/sofie/inc/TMVA/SOFIE_common.hxx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ template<>
164164
struct TensorType<bool> {
165165
static const std::string Name() { return "bool"; }
166166
};
167+
template<>
168+
struct TensorType<int8_t> {
169+
static const std::string Name() { return "int8_t"; }
170+
};
167171

168172
struct TensorMemoryInfo {
169173
std::string_view tensor_name;
@@ -225,8 +229,11 @@ std::string ConvertValuesToString(size_t n, const T * data) {
225229
ret << "{ ";
226230
for (size_t i = 0; i < n; i++) {
227231
if (std::is_floating_point_v<T>)
228-
ret << std::setprecision(std::numeric_limits<T>::max_digits10);
229-
ret << data[i];
232+
ret << std::setprecision(std::numeric_limits<T>::max_digits10) << data[i];
233+
else
234+
// cast in case of boolean (int8)
235+
ret << (int64_t) data[i];
236+
230237
if (i < n-1) ret << ", ";
231238
}
232239
ret << "}";

tmva/sofie/src/SOFIE_common.cxx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,21 +143,21 @@ std::string ConvertDimShapeToLength(const std::vector<Dim> & shape) {
143143
std::string length;
144144
// case of empty vectors return 1
145145
if (shape.empty()) return "1";
146-
size_t int_length = 0;
146+
int64_t int_length = -1;
147147
for (size_t i = 0; i < shape.size(); i++) {
148148
if (shape[i].isParam) {
149149
if (!length.empty()) length += " * ";
150150
length += shape[i].param;
151151
} else {
152-
if (int_length == 0)
152+
if (int_length == -1)
153153
int_length = shape[i].dim;
154154
else
155155
int_length *= shape[i].dim;
156156
}
157157
}
158158
// multiply the integer components to the parametric one
159-
// if larger than 1
160-
if (int_length > 0) {
159+
// if larger than 1 - otherwise returns -1
160+
if (int_length >= 0) {
161161
if (!length.empty() && int_length > 1) {
162162
length += " * ";
163163
length += std::to_string(int_length);

tmva/sofie/test/TestCustomModelsFromONNX.cxx

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@
327327

328328
#include "ScatterElements_FromONNX.hxx"
329329

330+
#include "NonZero_FromONNX.hxx"
331+
#include "NonZero_Constant_FromONNX.hxx"
332+
330333
#include "gtest/gtest.h"
331334

332335
constexpr float DEFAULT_TOLERANCE = 1e-3f;
@@ -3283,3 +3286,42 @@ TEST(ONNX, GatherND_3)
32833286
}
32843287
}
32853288

3289+
TEST(ONNX, NonZero)
3290+
{
3291+
// test GatherND elements using batch size as first dim (bs=2)
3292+
std::vector<int8_t> input = {0,1,0, 1,1,0, 0,0,1, 0,1,1 }; // shape is (2x2x3)
3293+
// output is tensor shape { 3, number of non zeros}
3294+
std::vector<int32_t> correct_output = { 0,0,0,1,1,1 , 0,1,1,0,1,1 , 1,0,1,2,1,2 };
3295+
3296+
TMVA_SOFIE_NonZero::Session s("NonZero_FromONNX.dat");
3297+
3298+
auto output = s.infer(input.data());
3299+
3300+
// Checking output size
3301+
EXPECT_EQ(output.size(), correct_output.size());
3302+
// Checking output
3303+
for (size_t i = 0; i < output.size(); ++i) {
3304+
EXPECT_EQ(output[i] , correct_output[i]);
3305+
}
3306+
}
3307+
3308+
TEST(ONNX, NonZero_Constant)
3309+
{
3310+
// test GatherND elements using batch size as first dim (bs=2)
3311+
//std::vector<int8_t> input = {0,1,0, 1,1,0, 0,0,1, 0,1,1 }; // shape is (2x2x3)
3312+
// output is tensor shape { 3, number of non zeros}
3313+
std::vector<int32_t> correct_output = { 0,0,0,1,1,1 , 0,1,1,0,1,1 , 1,0,1,2,1,2 };
3314+
3315+
TMVA_SOFIE_NonZero_Constant::Session s("NonZero_Constant_FromONNX.dat");
3316+
3317+
auto output = s.infer();
3318+
3319+
// Checking output size
3320+
EXPECT_EQ(output.size(), correct_output.size());
3321+
// Checking output
3322+
for (size_t i = 0; i < output.size(); ++i) {
3323+
EXPECT_EQ(output[i] , correct_output[i]);
3324+
}
3325+
}
3326+
3327+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+

2+
 onnx-example:Z
3+

4+
dataoutput"NonZero TestGraphZ
5+
data
6+

7+

8+

9+
b
10+
output
11+

12+

13+
B
171 Bytes
Binary file not shown.

tmva/sofie_parsers/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser
7373
src/ParseEinsum.cxx
7474
src/ParseRandom.cxx
7575
src/ParseScatterElements.cxx
76+
src/ParseNonZero.cxx
7677
${PROTO_SRCS}
7778
LIBRARIES PUBLIC
7879
protobuf::libprotobuf

0 commit comments

Comments
 (0)