Skip to content

Commit a56f6e3

Browse files
committed
[tmva][sofie] Fix binary operators for dynamic tensors
After optimizing the binary operators to avoid broadcasting, the fix also needed to be applied to the case of dynamic tensors. Also fix an issue in simplifying the resulting output shape of Reshape when dealing with dynamic tensor shapes.
1 parent 4e09bb1 commit a56f6e3

File tree

2 files changed

+85
-35
lines changed

2 files changed

+85
-35
lines changed

tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -309,55 +309,73 @@ public:
309309
// we allocate here output vector
310310
out << SP << SP << "if (" << fDimShapeA[i] << " != " << fDimShapeB[i] << " && (" << fDimShapeA[i]
311311
<< " != 1 || " << fDimShapeB[i] << " != 1))\n";
312-
out << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast shapes in operator " << opName
312+
out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast shapes in operator " << opName
313313
<< "\");\n";
314314
}
315315
}
316+
out << SP << "}\n";
316317
}
317318

318-
auto stridesA = UTILITY::ComputeStrideFromShape(fShapeA);
319-
auto stridesB = UTILITY::ComputeStrideFromShape(fShapeB);
320-
auto stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
319+
auto stridesA = UTILITY::ComputeStrideFromShape(fDimShapeA);
320+
auto stridesB = UTILITY::ComputeStrideFromShape(fDimShapeB);
321+
auto stridesY = UTILITY::ComputeStrideFromShape(fDimShapeY);
321322

322323
std::string compute_idx_A, compute_idx_B, compute_idx_Y;
323-
if (std::all_of(fShapeA.begin(), fShapeA.end(), [](size_t x) { return x == 1; })) {
324+
if (std::all_of(fDimShapeA.begin(), fDimShapeA.end(), [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) {
324325
compute_idx_A = "0";
325326
} else {
326-
for (size_t i = 0; i < fShapeA.size(); ++i) {
327-
if (fShapeA[i] == 1)
327+
for (size_t i = 0; i < fDimShapeA.size(); ++i) {
328+
if (fDimShapeA[i].dim == 1 || fDimShapeA[i].GetVal() == "1")
328329
continue;
329-
compute_idx_A +=
330-
" idx_" + std::to_string(i + (fShapeY.size() - fShapeA.size())) + " * " + stridesA[i] + " +";
330+
compute_idx_A += "idx_" + std::to_string(i + (fDimShapeY.size() - fDimShapeA.size()));
331+
if (stridesA[i].GetVal() != "1")
332+
compute_idx_A += " * " + stridesA[i].GetVal();
333+
compute_idx_A += " + ";
331334
}
332-
compute_idx_A.pop_back();
335+
// remove last 3 character " + "
336+
for (int j = 0; j < 3; j++)
337+
compute_idx_A.pop_back();
333338
}
334-
if (std::all_of(fShapeB.begin(), fShapeB.end(), [](size_t x) { return x == 1; })) {
339+
if (std::all_of(fDimShapeB.begin(), fDimShapeB.end(), [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) {
335340
compute_idx_B = "0";
336341
} else {
337-
for (size_t i = 0; i < fShapeB.size(); ++i) {
338-
if (fShapeB[i] == 1)
342+
for (size_t i = 0; i < fDimShapeB.size(); ++i) {
343+
if (fDimShapeB[i].dim == 1 || fDimShapeB[i].GetVal() == "1")
339344
continue;
340-
compute_idx_B +=
341-
" idx_" + std::to_string(i + (fShapeY.size() - fShapeB.size())) + " * " + stridesB[i] + " +";
345+
compute_idx_B += "idx_" + std::to_string(i + (fDimShapeY.size() - fDimShapeB.size()));
346+
if (stridesB[i].GetVal() != "1")
347+
compute_idx_B += " * " + stridesB[i].GetVal();
348+
compute_idx_B += " + ";
342349
}
343-
compute_idx_B.pop_back();
350+
// remove last 3 character " + "
351+
for (int j = 0; j < 3; j++)
352+
compute_idx_B.pop_back();
344353
}
345-
for (size_t i = 0; i < fShapeY.size(); ++i) {
346-
if (fShapeY[i] != 1) {
347-
out << std::string(i + 1, ' ') << "for(size_t idx_" << i << "=0; idx_" << i << "<" << fShapeY[i]
354+
int nloop = 0;
355+
for (size_t i = 0; i < fDimShapeY.size(); ++i) {
356+
if (fDimShapeY[i].dim != 1 && fDimShapeY[i].GetVal() != "1") {
357+
nloop++;
358+
for (int j = 0; j < nloop; j++) out << SP;
359+
out << "for (size_t idx_" << i << " = 0; idx_" << i << " < " << fDimShapeY[i]
348360
<< "; ++idx_" << i << "){\n";
349-
compute_idx_Y += "idx_" + std::to_string(i) + "*" + stridesY[i] + "+";
361+
compute_idx_Y += "idx_" + std::to_string(i);
362+
if (stridesY[i].GetVal() != "1")
363+
compute_idx_Y += " * " + stridesY[i].GetVal();
364+
compute_idx_Y += " + ";
350365
}
351366
}
352-
compute_idx_Y.pop_back();
353-
out << SP << SP << "tensor_" << fNY << "[" << compute_idx_Y << "] = "
367+
// remove last 3 characters " + "
368+
for (int j = 0; j < 3; j++)
369+
compute_idx_Y.pop_back();
370+
for (int j = 0; j < nloop+1; j++) out << SP;
371+
out << "tensor_" << fNY << "[" << compute_idx_Y << "] = "
354372
<< BinaryOperatorTrait<T, Op>::Op("tensor_" + fNA + "[" + compute_idx_A + "]",
355373
"tensor_" + fNB + "[" + compute_idx_B + "]")
356374
<< " ;\n";
357-
for (size_t i = 0; i < fShapeY.size(); ++i) {
358-
if (fShapeY[i] != 1) {
359-
out << std::string(fShapeY.size() - i + 1, ' ') << "}\n";
360-
}
375+
376+
for (int i = nloop; i > 0; i--) {
377+
for (int j = 0; j < i; j++) out << SP;
378+
out << "}\n";
361379
}
362380
return out.str();
363381
}

tmva/sofie/inc/TMVA/ROperator_Reshape.hxx

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
#include "TMVA/RModel.hxx"
77

88
#include <cassert>
9+
#include <cctype>
910
#include <sstream>
11+
#include <algorithm>
1012

1113
namespace TMVA{
1214
namespace Experimental{
@@ -100,25 +102,55 @@ public:
100102
tmp.erase(tmp.begin() + i);
101103
auto tmp_length = ConvertDimShapeToLength(tmp);
102104
auto input_length = ConvertDimShapeToLength(input_shape);
105+
if (fVerbose)
106+
std::cout << "reshape- try simplifying " << ConvertDimShapeToString(input_shape) << " with length "
107+
<< input_length << " to " << tmp_length << std::endl;
108+
103109
if (IsInteger(tmp_length) && IsInteger(input_length))
104110
output_shape[i] = Dim{static_cast<size_t>(std::stoi(input_length) / std::stoi(tmp_length))};
105111
else {
106112
//we can try simplifying expression if tmp_length is integer and part of input_length
107113
// contains tmp_length
108114
bool canSimplify = false;
115+
std::vector <Dim> reduced_input;
109116
if (IsInteger(tmp_length)) {
110-
std::vector<Dim> reduced_input = input_shape;
111-
for (auto & s : input_shape) {
112-
if (s.GetVal() == tmp_length) {
113-
//erase value in the reduced_input vector
114-
auto itr = std::find(reduced_input.begin(), reduced_input.end(), s);
115-
reduced_input.erase(itr);
117+
118+
// try to tokenize with * the input length
119+
120+
std::stringstream ss(input_length);
121+
122+
std::string token;
123+
124+
// Tokenizing w.r.t. space '*'
125+
while(getline(ss, token, '*'))
126+
{
127+
// remove any whitespace
128+
token.erase(std::remove_if(token.begin(), token.end(),
129+
[](unsigned char x) { return std::isspace(x); }), token.end());
130+
if (token != tmp_length) {
131+
if (IsInteger(token)) {
132+
size_t il = static_cast<size_t>(std::stoi(input_length));
133+
size_t tl = static_cast<size_t>(std::stoi(tmp_length));
134+
if ((il % tl) == 0) {
135+
canSimplify = true;
136+
reduced_input.push_back(Dim{il / tl});
137+
}
138+
} else {
139+
reduced_input.push_back(Dim{token});
140+
}
141+
} else {
142+
// token is equal to tmp_length, can be not considered and is simplified
116143
canSimplify = true;
117-
break;
118144
}
119145
}
120-
if (canSimplify)
121-
output_shape[i] = Dim{std::string("(") + ConvertDimShapeToLength(reduced_input) + ")", static_cast<size_t>(-1)};
146+
}
147+
if (canSimplify) {
148+
// if length contains * we need to add some brackets
149+
std::string res_shape = ConvertDimShapeToLength(reduced_input);
150+
if (res_shape.find('*') != std::string::npos)
151+
output_shape[i] = Dim{std::string("(") + res_shape + ")", static_cast<size_t>(-1)};
152+
else
153+
output_shape[i] = Dim{res_shape};
122154
}
123155
if (!canSimplify)
124156
output_shape[i] = Dim{std::string("(") + input_length + " / (" + tmp_length + "))", static_cast<size_t>(-1)};

0 commit comments

Comments
 (0)