
Commit 287a6bf

fix: clang-format for ROperator_BasicBinary

sanjibansg authored and lmoneta committed
1 parent 6e617d6

File tree

1 file changed: 107 additions, 86 deletions

tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 107 additions & 86 deletions
@@ -7,54 +7,59 @@
 
 #include <sstream>
 
-namespace TMVA{
-namespace Experimental{
-namespace SOFIE{
+namespace TMVA {
+namespace Experimental {
+namespace SOFIE {
 
-enum EBasicBinaryOperator { Add, Sub, Mul, Div, Pow };
+enum EBasicBinaryOperator {
+   Add,
+   Sub,
+   Mul,
+   Div,
+   Pow
+};
 
 template <typename T, EBasicBinaryOperator Op1>
 struct BinaryOperatorTrait {};
 
 template <typename T>
 struct BinaryOperatorTrait<T, Add> {
    static const std::string Name() { return "Add"; }
-   static std::string Op(const std::string & t1, const std::string t2) { return t1 + " + " + t2; }
-   static T Func(T t1, T t2) {return t1 + t2;}
+   static std::string Op(const std::string &t1, const std::string t2) { return t1 + " + " + t2; }
+   static T Func(T t1, T t2) { return t1 + t2; }
 };
 
 template <typename T>
 struct BinaryOperatorTrait<T, Sub> {
    static const std::string Name() { return "Sub"; }
-   static std::string Op(const std::string & t1, const std::string t2) { return t1 + " - " + t2; }
-   static T Func (T t1, T t2) { return t1 - t2;}
+   static std::string Op(const std::string &t1, const std::string t2) { return t1 + " - " + t2; }
+   static T Func(T t1, T t2) { return t1 - t2; }
 };
 
 template <typename T>
 struct BinaryOperatorTrait<T, Mul> {
    static const std::string Name() { return "Mul"; }
-   static std::string Op(const std::string & t1, const std::string t2) { return t1 + " * " + t2; }
-   static T Func (T t1, T t2) { return t1 * t2;}
+   static std::string Op(const std::string &t1, const std::string t2) { return t1 + " * " + t2; }
+   static T Func(T t1, T t2) { return t1 * t2; }
 };
 
 template <typename T>
 struct BinaryOperatorTrait<T, Div> {
    static const std::string Name() { return "Div"; }
-   static std::string Op(const std::string & t1, const std::string t2) { return t1 + " / " + t2; }
-   static T Func (T t1, T t2) { return t1/t2;}
+   static std::string Op(const std::string &t1, const std::string t2) { return t1 + " / " + t2; }
+   static T Func(T t1, T t2) { return t1 / t2; }
 };
 
 template <typename T>
 struct BinaryOperatorTrait<T, Pow> {
    static const std::string Name() { return "Pow"; }
-   static std::string Op(const std::string & t1, const std::string t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
-   static T Func (T t1, T t2) { return std::pow(t1,t2);}
+   static std::string Op(const std::string &t1, const std::string t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
+   static T Func(T t1, T t2) { return std::pow(t1, t2); }
 };
 
-template<typename T, EBasicBinaryOperator Op>
-class ROperator_BasicBinary final : public ROperator{
+template <typename T, EBasicBinaryOperator Op>
+class ROperator_BasicBinary final : public ROperator {
 private:
-
    int fBroadcastFlag = 0;
    std::string fNA;
    std::string fNB;
@@ -71,28 +76,29 @@ private:
    std::vector<Dim> fDimShapeY;
 
 public:
-   ROperator_BasicBinary(){}
-   ROperator_BasicBinary(std::string nameA, std::string nameB, std::string nameY):
-      fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY)){
-      fInputTensorNames = { fNA, fNB };
-      fOutputTensorNames = { fNY };
-   }
+   ROperator_BasicBinary() {}
+   ROperator_BasicBinary(std::string nameA, std::string nameB, std::string nameY)
+      : fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY))
+   {
+      fInputTensorNames = {fNA, fNB};
+      fOutputTensorNames = {fNY};
+   }
 
    // type of output given input
-   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
-      return input;
-   }
+   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
 
    // shape of output tensors given input tensors
-   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
+   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override
+   {
       // assume now inputs have same shape (no broadcasting)
       auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
      return ret;
    }
-
-   void Initialize(RModel& model) override {
+
+   void Initialize(RModel &model) override
+   {
      // input must be a graph input, or already initialized intermediate tensor
-      if (!model.CheckIfTensorAlreadyExist(fNA)){
+      if (!model.CheckIfTensorAlreadyExist(fNA)) {
         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNA + "is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
@@ -113,10 +119,12 @@ public:
         fShapeB = model.GetTensorShape(fNB);
         fDimShapeB = ConvertShapeToDim(fShapeB);
      }
-      if (dynamicInputs & 1 && model.Verbose() )
-         std::cout << BinaryOperatorTrait<T, Op>::Name() << " : input " << fNA << " is dynamic " << ConvertShapeToString(fDimShapeA) << " ";
+      if (dynamicInputs & 1 && model.Verbose())
+         std::cout << BinaryOperatorTrait<T, Op>::Name() << " : input " << fNA << " is dynamic "
+                   << ConvertShapeToString(fDimShapeA) << " ";
      if (dynamicInputs & 2 && model.Verbose())
-         std::cout << BinaryOperatorTrait<T, Op>::Name() << " : input " << fNB << " is dynamic " << ConvertShapeToString(fDimShapeB) << " ";
+         std::cout << BinaryOperatorTrait<T, Op>::Name() << " : input " << fNB << " is dynamic "
+                   << ConvertShapeToString(fDimShapeB) << " ";
      std::cout << std::endl;
      // check if need to broadcast at initialization time if shapes are known and different
      // (we could broadcast the tensor tensor to maximum values of dynamic shapes - to be done)
@@ -125,7 +133,7 @@ public:
         auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB);
         fBroadcastFlag = ret.first;
         fShapeY = ret.second;
-         bool broadcast = ret.first > 0; 
+         bool broadcast = ret.first > 0;
         if (broadcast) {
            // Y is the common shape of A and B
            bool broadcastA = ret.first & 2;
@@ -186,16 +194,16 @@ public:
              model.SetNotWritableInitializedTensor(nameB);
              fIsOutputConstant = true;
              if (model.Verbose()) {
-                 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
-                           << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY
-                           << " " << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl;
+                 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
+                           << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " "
+                           << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl;
              }
           } else {
              model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
              if (model.Verbose()) {
-                 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
-                           << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY
-                           << " " << ConvertShapeToString(fShapeY) << std::endl;
+                 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
+                           << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " "
+                           << ConvertShapeToString(fShapeY) << std::endl;
              }
           }
           // we convert non-dim shapes to Dim shapes
@@ -211,17 +219,18 @@ public:
        if (ret.first & 4) {
           // check if one of the parameter is an input dimension
           // define function to find this
-           auto IsInputDimParam = [&](const std::string & p) {
+           auto IsInputDimParam = [&](const std::string &p) {
              auto inputNames = model.GetInputTensorNames();
-              for (auto & input : inputNames) {
-                 for (auto & i_s : model.GetDimTensorShape(input)) {
-                    if (i_s.isParam && i_s.param == p) return true;
+              for (auto &input : inputNames) {
+                 for (auto &i_s : model.GetDimTensorShape(input)) {
+                    if (i_s.isParam && i_s.param == p)
+                       return true;
                 }
              }
             return false;
           };
          for (size_t i = 0; i < fDimShapeY.size(); i++) {
-             auto & s = fDimShapeY[i];
+             auto &s = fDimShapeY[i];
             if (s.isParam && s.param.find("std::max") != std::string::npos) {
                if (IsInputDimParam(fDimShapeA[i].param)) {
                   // case dim is 1 we indicate that the input parameter is equal to 1
@@ -238,7 +247,7 @@ public:
               }
            }
        }
-       
+
       model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fDimShapeY);
       if (model.Verbose()) {
          std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << ConvertShapeToString(fDimShapeA) << " , "
@@ -247,22 +256,25 @@ public:
       }
    }
 
-   std::string GenerateInitCode() override {
+   std::string GenerateInitCode() override
+   {
      std::stringstream out;
      return out.str();
   }
 
-   std::string Generate(std::string opName) override {
+   std::string Generate(std::string opName) override
+   {
 
-      if (fIsOutputConstant) return "";
+      if (fIsOutputConstant)
+         return "";
 
      opName = "op_" + opName;
 
     if (fDimShapeY.empty()) {
        throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
     }
     std::stringstream out;
-      out << SP << "\n//------ " << BinaryOperatorTrait<T,Op>::Name() << "\n";
+      out << SP << "\n//------ " << BinaryOperatorTrait<T, Op>::Name() << "\n";
     auto length = ConvertDimShapeToLength(fDimShapeY);
     std::string typeName = TensorType<T>::Name();
@@ -273,82 +285,91 @@ public:
        auto lengthB = ConvertDimShapeToLength(fDimShapeB);
        out << SP << "if (" << lengthA << "!=" << lengthB << ") {\n";
        // check if A->B or B->A
-        //bool broadcastable = true;
+        // bool broadcastable = true;
        for (size_t i = 0; i < fDimShapeY.size(); i++) {
-           if (fBroadcastFlag & 5 && fDimShapeY[i] == fDimShapeA[i] && fDimShapeA[i].dim > 1 && fDimShapeB[i].isParam) {
+           if (fBroadcastFlag & 5 && fDimShapeY[i] == fDimShapeA[i] && fDimShapeA[i].dim > 1 &&
+               fDimShapeB[i].isParam) {
              // B->A B[i] needs to be 1
             out << SP << SP << "if (" << fDimShapeB[i] << "!= 1)\n";
             out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast B->A in operator "
-                 << opName << "\");\n";
+                << opName << "\");\n";
          }
-           if (fBroadcastFlag & 6 && fDimShapeY[i] == fDimShapeB[i] && fDimShapeB[i].dim > 1 && fDimShapeA[i].isParam) {
-              //A-> B A[i] needs to be 1
+           if (fBroadcastFlag & 6 && fDimShapeY[i] == fDimShapeB[i] && fDimShapeB[i].dim > 1 &&
+               fDimShapeA[i].isParam) {
+              // A-> B A[i] needs to be 1
             out << SP << SP << "if (" << fDimShapeA[i] << "!= 1)\n";
            out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast A->B in operator "
-                 << opName << "\");\n";
-           }
-           else if (fDimShapeA[i].isParam && fDimShapeB[i].isParam) {
+                << opName << "\");\n";
+           } else if (fDimShapeA[i].isParam && fDimShapeB[i].isParam) {
             // both shapes are parametric and we broadcast to maximum
             // we allocate here output vector
-              out << SP << SP << "if (" << fDimShapeA[i] << " != " << fDimShapeB[i] << " && ("
-                  << fDimShapeA[i] << " != 1 || " << fDimShapeB[i] << " != 1))\n";
-              out << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast shapes in operator "
-                  << opName << "\");\n";
+              out << SP << SP << "if (" << fDimShapeA[i] << " != " << fDimShapeB[i] << " && (" << fDimShapeA[i]
+                  << " != 1 || " << fDimShapeB[i] << " != 1))\n";
+              out << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast shapes in operator " << opName
+                  << "\");\n";
          }
       }
     }
-
+
     auto stridesA = UTILITY::ComputeStrideFromShape(fShapeA);
     auto stridesB = UTILITY::ComputeStrideFromShape(fShapeB);
     auto stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
 
     std::string compute_idx_A, compute_idx_B, compute_idx_Y;
-     if (std::all_of(fShapeA.begin(), fShapeA.end(), [](size_t x) { return x == 1; })){
+     if (std::all_of(fShapeA.begin(), fShapeA.end(), [](size_t x) { return x == 1; })) {
       compute_idx_A = "0";
     } else {
-        for(size_t i = 0; i<fShapeA.size(); ++i){
-           if(fShapeA[i]==1) continue;
-           compute_idx_A += " idx_"+fNY+std::to_string(i+(fShapeY.size()-fShapeA.size()))+" * "+stridesA[i]+" +";
+        for (size_t i = 0; i < fShapeA.size(); ++i) {
+           if (fShapeA[i] == 1)
+              continue;
+           compute_idx_A +=
+              " idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeA.size())) + " * " + stridesA[i] + " +";
       }
       compute_idx_A.pop_back();
    }
-     if (std::all_of(fShapeB.begin(), fShapeB.end(), [](size_t x) { return x == 1; })){
+     if (std::all_of(fShapeB.begin(), fShapeB.end(), [](size_t x) { return x == 1; })) {
      compute_idx_B = "0";
    } else {
-        for(size_t i = 0; i<fShapeB.size(); ++i){
-           if(fShapeB[i]==1) continue;
-           compute_idx_B += " idx_"+fNY+std::to_string(i+(fShapeY.size()-fShapeB.size()))+" * "+stridesB[i]+" +";
+        for (size_t i = 0; i < fShapeB.size(); ++i) {
+           if (fShapeB[i] == 1)
+              continue;
+           compute_idx_B +=
+              " idx_" + fNY + std::to_string(i + (fShapeY.size() - fShapeB.size())) + " * " + stridesB[i] + " +";
      }
      compute_idx_B.pop_back();
    }
-     for(size_t i=0; i<fShapeY.size(); ++i){
-        if(fShapeY[i]!=1){
-           out<<std::string(i + 1, ' ')<<"for(size_t idx_"<<fNY<<i<<"=0; idx_"<<fNY<<i<<"<"<<fShapeY[i]<<"; ++idx_"<<fNY<<i<<"){\n";
-           compute_idx_Y += "idx_"+fNY+std::to_string(i)+"*"+stridesY[i]+"+";
+     for (size_t i = 0; i < fShapeY.size(); ++i) {
+        if (fShapeY[i] != 1) {
+           out << std::string(i + 1, ' ') << "for(size_t idx_" << fNY << i << "=0; idx_" << fNY << i << "<"
+               << fShapeY[i] << "; ++idx_" << fNY << i << "){\n";
+           compute_idx_Y += "idx_" + fNY + std::to_string(i) + "*" + stridesY[i] + "+";
      }
    }
    compute_idx_Y.pop_back();
-     out << SP << SP << "tensor_" << fNY <<"["<<compute_idx_Y<<"] = "<<BinaryOperatorTrait<T,Op>::Op("tensor_"+ fNA + "["+compute_idx_A+"]", "tensor_"+ fNB + "["+compute_idx_B+"]")<<" ;\n";
-     for(size_t i=0; i<fShapeY.size(); ++i){
-        if(fShapeY[i]!=1){
-           out<<std::string(fShapeY.size()-i+1, ' ')<<"}\n";
+     out << SP << SP << "tensor_" << fNY << "[" << compute_idx_Y << "] = "
+         << BinaryOperatorTrait<T, Op>::Op("tensor_" + fNA + "[" + compute_idx_A + "]",
+                                           "tensor_" + fNB + "[" + compute_idx_B + "]")
+         << " ;\n";
+     for (size_t i = 0; i < fShapeY.size(); ++i) {
+        if (fShapeY[i] != 1) {
+           out << std::string(fShapeY.size() - i + 1, ' ') << "}\n";
      }
    }
    return out.str();
  }
 
-   std::vector<std::string> GetStdLibs() override {
+   std::vector<std::string> GetStdLibs() override
+   {
     if (Op == EBasicBinaryOperator::Pow) {
-         return { std::string("cmath") };
+         return {std::string("cmath")};
     } else {
       return {};
    }
  }
 };
 
-}//SOFIE
-}//Experimental
-}//TMVA
-
+} // namespace SOFIE
+} // namespace Experimental
+} // namespace TMVA
 
-#endif //TMVA_SOFIE_ROperator_BasicBinary
+#endif // TMVA_SOFIE_ROperator_BasicBinary
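Although the change above is formatting-only, the diff is easier to review knowing the pattern it touches: each BinaryOperatorTrait specialization pairs a string-building Op(), used by Generate() to emit inference source code, with an eager Func(), used by Initialize() to constant-fold when both inputs are already known. A minimal standalone sketch of that pattern (the SOFIE machinery, i.e. ROperator, RModel, and tensor bookkeeping, is omitted):

// Standalone sketch, not the SOFIE header itself: the trait couples a
// string-building Op(), used when generating inference code, with an eager
// Func(), used to fold constants at initialization time.
#include <cmath>
#include <iostream>
#include <string>

enum EBasicBinaryOperator { Add, Sub, Mul, Div, Pow };

template <typename T, EBasicBinaryOperator Op1>
struct BinaryOperatorTrait {};

template <typename T>
struct BinaryOperatorTrait<T, Add> {
   static const std::string Name() { return "Add"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " + " + t2; }
   static T Func(T t1, T t2) { return t1 + t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Pow> {
   static const std::string Name() { return "Pow"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
   static T Func(T t1, T t2) { return std::pow(t1, t2); }
};

int main()
{
   // Code-generation path: emit the expression that ends up in the model source.
   std::cout << BinaryOperatorTrait<float, Add>::Op("tensor_A[0]", "tensor_B[0]") << "\n";
   // Constant-folding path: evaluate the same operation directly.
   std::cout << BinaryOperatorTrait<float, Pow>::Func(2.f, 10.f) << "\n";
}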

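Initialize() resolves the output shape with UTILITY::MultidirectionalBroadcastShape, which applies ONNX-style multidirectional broadcasting. A sketch of that rule under stated assumptions (the BroadcastShapes helper below is a hypothetical stand-in; the real SOFIE utility also returns a flag encoding which inputs need broadcasting):

// Hypothetical helper illustrating ONNX-style multidirectional broadcast.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> BroadcastShapes(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
   // Left-pad the shorter shape with 1s, then take the max per axis;
   // mismatched axes are only legal if one of them is 1.
   const std::size_t n = std::max(a.size(), b.size());
   a.insert(a.begin(), n - a.size(), 1);
   b.insert(b.begin(), n - b.size(), 1);
   std::vector<std::size_t> y(n);
   for (std::size_t i = 0; i < n; ++i) {
      if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
         throw std::runtime_error("shapes are not broadcastable");
      y[i] = std::max(a[i], b[i]);
   }
   return y;
}

int main()
{
   // {2,3,4} combined with {3,1} broadcasts to {2,3,4}.
   for (std::size_t d : BroadcastShapes({2, 3, 4}, {3, 1}))
      std::cout << d << " ";
   std::cout << "\n";
}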
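The loop nest emitted by Generate() indexes the flattened tensors with row-major strides and simply drops size-1 axes from an input's index expression, which is how the broadcast is realized without materializing an expanded tensor. A small sketch of that arithmetic, assuming ComputeStrideFromShape behaves like the simple row-major version below (a hypothetical standalone stand-in, not SOFIE's UTILITY helper):

// Hypothetical row-major stride computation:
// stride[i] = product of all dimensions to the right of axis i.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> ComputeStrideRowMajor(const std::vector<std::size_t> &shape)
{
   std::vector<std::size_t> strides(shape.size(), 1);
   for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
      strides[i] = strides[i + 1] * shape[i + 1];
   return strides;
}

int main()
{
   // Output shape {2, 3, 4} -> strides {12, 4, 1}. An input of shape
   // {1, 3, 4} reuses strides {4, 1} and omits the size-1 axis from its
   // index expression, so each of its elements is read twice.
   for (std::size_t s : ComputeStrideRowMajor({2, 3, 4}))
      std::cout << s << " ";
   std::cout << "\n"; // prints: 12 4 1
}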