Commit 1c6fd66

[tmva][sofie] Add support for Dynamic tensors in more operators
Include now also BatchNormalization, ConstantOfShape, Gather, Reshape and Sigmoid. Fix the broadcasting of dynamic tensors in BasicBinary.
Parent: 524e71d
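The BasicBinary fix changes how the return flag of UTILITY::MultidirectionalBroadcastShape is interpreted: it is now read as a bit mask rather than an enumerated value. Bit 1 (& 1) means input B must be broadcast, bit 2 (& 2) means input A must be broadcast, and bit 4 (& 4) marks shapes that can only be validated at run time. A minimal stand-alone sketch of that decoding, using only the bit values visible in the diff (the struct and helper names below are illustrative, not SOFIE API):

    #include <iostream>

    // Illustrative decoding of the broadcast bit mask adopted by this commit.
    // Bit values are taken from the diff (ret.first & 1, & 2, & 4).
    struct BroadcastPlan {
       bool broadcastB;    // bit 1: B needs broadcasting to the output shape
       bool broadcastA;    // bit 2: A needs broadcasting to the output shape
       bool runtimeCheck;  // bit 4: parametric dims, check only possible at run time
    };

    BroadcastPlan DecodeBroadcastFlag(int flag) {
       return BroadcastPlan{(flag & 1) != 0, (flag & 2) != 0, (flag & 4) != 0};
    }

    int main() {
       // flag == 3: both inputs need broadcasting. The old code special-cased
       // this value (ret.first == 1 || ret.first == 3); the mask makes it implicit.
       BroadcastPlan plan = DecodeBroadcastFlag(3);
       std::cout << plan.broadcastA << plan.broadcastB << plan.runtimeCheck << "\n"; // prints 110
    }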

File tree: 9 files changed, +222 -96 lines

tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 58 additions & 10 deletions
@@ -55,6 +55,7 @@ template<typename T, EBasicBinaryOperator Op>
 class ROperator_BasicBinary final : public ROperator{
 private:
 
+   int fBroadcastFlag = 0;
    std::string fNA;
    std::string fNB;
    std::string fNBroadcastedA;
@@ -114,12 +115,14 @@ public:
       // case of known shapes
       if (!fShapeA.empty() && !fShapeB.empty()) {
          auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB);
+         fBroadcastFlag = ret.first;
          fShapeY = ret.second;
+         std::cout << BinaryOperatorTrait<T, Op>::Name() << "checking for defined shapes " << fBroadcastFlag << " " << ConvertShapeToString(fShapeY) << std::endl;
          bool broadcast = ret.first > 0;
          if (broadcast) {
             // Y is the common shape of A and B
-            bool broadcastA = ret.first > 1;
-            bool broadcastB = ret.first == 1 || ret.first == 3;
+            bool broadcastA = ret.first & 2;
+            bool broadcastB = ret.first & 1;
             // Broadcast A to Y
             if (broadcastA) {
                fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
@@ -191,17 +194,28 @@ public:
       else {
          // case A or B have dynamic shapes. We need to broadcast if shape are not same
          auto ret = UTILITY::MultidirectionalBroadcastShape(fDimShapeA, fDimShapeB);
+         fBroadcastFlag = ret.first;
          fDimShapeY = ret.second;
-         if (ret.first > 1) {
+         std::cout << BinaryOperatorTrait<T, Op>::Name() << " : checking for Dim shapes " << fBroadcastFlag << " " << ConvertShapeToString(fDimShapeY) << std::endl;
+         if (ret.first & 2) {
             // case we broadcast A
             fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
             model.AddIntermediateTensor(fNBroadcastedA, model.GetTensorType(fNA), fDimShapeY);
          }
-         if (ret.first == 1 || ret.first == 3) {
+         if (ret.first & 1) {
             // case we broadcast B
             fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
             model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fDimShapeY);
          }
+         // case of all parametric shapes and we know only at run time
+         // we don't add in this case an intermediate tensor for broadcasting
+         // if (ret.first == 4) {
+         //    for (auto & d : fDimShapeY) {
+         //       if (d.isParam && d.param.find("broadcast") != std::string::npos) {
+         //          d.param += fNY;
+         //       }
+         //    }
+         // }
          // add output tensor
          model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fDimShapeY);
       }
@@ -212,11 +226,11 @@ public:
       return out.str();
    }
 
-   std::string Generate(std::string OpName) override {
+   std::string Generate(std::string opName) override {
 
       if (fIsOutputConstant) return "";
 
-      OpName = "op_" + OpName;
+      opName = "op_" + opName;
 
       if (fDimShapeY.empty()) {
          throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
@@ -225,21 +239,55 @@ public:
       out << SP << "\n//------ " << BinaryOperatorTrait<T,Op>::Name() << "\n";
       auto length = ConvertDimShapeToLength(fDimShapeY);
       std::string typeName = TensorType<T>::Name();
+      // we need to check if we can broadcast (case flag has bit 4 set)
+      if (fBroadcastFlag & 4) {
+         // need to check if shapes are the same
+         auto lengthA = ConvertDimShapeToLength(fDimShapeA);
+         auto lengthB = ConvertDimShapeToLength(fDimShapeB);
+         out << SP << "if (" << lengthA << "!=" << lengthB << ") {\n";
+         // check if A->B or B->A
+         //bool broadcastable = true;
+         for (size_t i = 0; i < fDimShapeY.size(); i++) {
+            if (fBroadcastFlag & 5 && fDimShapeY[i] == fDimShapeA[i] && fDimShapeA[i].dim > 1 && fDimShapeB[i].isParam) {
+               // B->A B[i] needs to be 1
+               out << SP << SP << "if (" << fDimShapeB[i] << "!= 1)\n";
+               out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast B->A in operator "
+                   << opName << "\");\n";
+            }
+            if (fBroadcastFlag & 6 && fDimShapeY[i] == fDimShapeB[i] && fDimShapeB[i].dim > 1 && fDimShapeA[i].isParam) {
+               // A->B A[i] needs to be 1
+               out << SP << SP << "if (" << fDimShapeA[i] << "!= 1)\n";
+               out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast A->B in operator "
+                   << opName << "\");\n";
+            }
+            else if (fDimShapeA[i].isParam && fDimShapeB[i].isParam) {
+               // both shapes are parametric and we broadcast to maximum
+               // we allocate here output vector
+               out << SP << SP << "if (" << fDimShapeA[i] << " != " << fDimShapeB[i] << " && ("
+                   << fDimShapeA[i] << " != 1 || " << fDimShapeB[i] << " != 1))\n";
+               out << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast shapes in operator "
+                   << opName << "\");\n";
+            }
+         }
+      } else {
+         out << SP << "{\n";
+      }
       // Broadcast A if it's uninitialized
       // use broadcasting function where we pass an already allocated tensor to minimize memory allocations
       if (!fNBroadcastedA.empty()) {
-         out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
-         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNA << ", "
+         out << SP << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
+         out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNA << ", "
             << ConvertDimShapeToString(fDimShapeA) << ", " << ConvertDimShapeToString(fDimShapeY)
             << ", fTensor_" << fNBroadcastedA << ");\n";
       }
       // Broadcast B if it's uninitialized
       if (!fNBroadcastedB.empty()) {
-         out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
-         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNB << ", "
+         out << SP << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
+         out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNB << ", "
             << ConvertDimShapeToString(fDimShapeB) << ", " << ConvertDimShapeToString(fDimShapeY)
             << ", fTensor_" << fNBroadcastedB << ");\n";
       }
+      out << SP << "}\n"; // end if on broadcasting
       const std::string& nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
       const std::string& nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
       out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
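For a sense of what the new bit-4 path produces, here is a hand-written approximation of the guard and loop that Generate() emits when both inputs are fully parametric. It assumes hypothetical input shapes A = {bs, n} and B = {bs, m}; every name (addBinaryOp, bs, n, m) is illustrative rather than actual SOFIE output:

    #include <algorithm>
    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Sketch of the run-time broadcast guard generated for a binary Add whose
    // inputs have the parametric shapes {bs, n} and {bs, m} (illustrative names).
    std::vector<float> addBinaryOp(std::size_t bs, std::size_t n, std::size_t m,
                                   const std::vector<float>& A,
                                   const std::vector<float>& B) {
       if (bs * n != bs * m) {
          // parametric dims may differ only if one of them is 1
          if (n != m && (n != 1 || m != 1))
             throw std::runtime_error("SOFIE - Cannot broadcast shapes in operator op_Add");
       }
       const std::size_t dim = std::max(n, m);
       std::vector<float> Y(bs * dim);
       for (std::size_t i = 0; i < bs; i++)
          for (std::size_t j = 0; j < dim; j++)
             // a dimension of size 1 is broadcast by reusing its single element
             Y[i * dim + j] = A[i * n + (n == 1 ? 0 : j)] + B[i * m + (m == 1 ? 0 : j)];
       return Y;
    }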

tmva/sofie/inc/TMVA/ROperator_BatchNormalization.hxx

Lines changed: 18 additions & 15 deletions
@@ -32,12 +32,12 @@ private:
    std::string fNY;
    EActivationType fActivation;
 
-   std::vector<size_t> fShapeX;
+   std::vector<Dim> fShapeX;
    std::vector<size_t> fShapeScale;
    std::vector<size_t> fShapeB;
    std::vector<size_t> fShapeMean;
    std::vector<size_t> fShapeVar;
-   std::vector<size_t> fShapeY;
+   std::vector<Dim> fShapeY;
 
    std::string fType;
 
@@ -109,7 +109,7 @@ public:
          std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNVar + " fnv is not found in model");
       }
 
-      fShapeX = model.GetTensorShape(fNX);
+      fShapeX = model.GetDimTensorShape(fNX);
 
       if (fShapeX.size() < 2 || fShapeX.size() > 4) {
          throw
@@ -123,16 +123,17 @@ public:
       fShapeY = fShapeX;
       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
 
-      if (fShapeB.size() == 1) {
+      if (fShapeB.size() == 1 && !model.IsDynamicTensor(fNX)) {
+         auto shapeX = model.GetTensorShape(fNX);
          // Broadcast scale, bias, input_mean and input_var to shape_X
          auto original_B = model.GetInitializedTensorData(fNB);
         auto original_S = model.GetInitializedTensorData(fNScale);
         auto original_M = model.GetInitializedTensorData(fNMean);
         auto original_V = model.GetInitializedTensorData(fNVar);
-         size_t batchSize = fShapeX[0];
-         size_t channels = fShapeX[1];
-         size_t height = (fShapeX.size() > 2) ? fShapeX[2] : 1;
-         size_t width = (fShapeX.size() > 3) ? fShapeX[3] : 1;
+         size_t batchSize = shapeX[0];
+         size_t channels = shapeX[1];
+         size_t height = (shapeX.size() > 2) ? shapeX[2] : 1;
+         size_t width = (shapeX.size() > 3) ? shapeX[3] : 1;
         size_t n = batchSize * channels * height * width;
         if (fType == "float") {
            float *original_bias = static_cast<float *>(original_B.get());
@@ -181,6 +182,8 @@ public:
            fShapeMean = model.GetTensorShape(fNMean);
            fShapeVar = model.GetTensorShape(fNVar);
         }
+      } else {
+         // we need to broadcast at run time
      }
   }
 
@@ -192,15 +195,15 @@ public:
 
      std::stringstream out;
      //// Batch Norm op
-      size_t batchSize = fShapeX[0];
-      size_t channels = fShapeX[1];
-      size_t height = (fShapeX.size() > 2) ? fShapeX[2] : 1;
-      size_t width = (fShapeX.size() > 3) ? fShapeX[3] : 1;
-      size_t n = batchSize * channels * height * width;
+      std::string batchSize = fShapeX[0].GetVal();
+      std::string channels = fShapeX[1].GetVal();
+      std::string height = (fShapeX.size() > 2) ? fShapeX[2].GetVal() : "1";
+      std::string width = (fShapeX.size() > 3) ? fShapeX[3].GetVal() : "1";
+      auto n = ConvertDimShapeToLength(fShapeX);
 
      //// copy X into Y
      out << "\n\n//---- BatchNorm\n";
-      out << SP << "constexpr int " << OpName << "_N =" << batchSize * channels * height * width << ";\n";
+      out << SP << "constexpr int " << OpName << "_N =" << n << ";\n";
      out << SP << "constexpr int "<<OpName<< "_incx = 1;\n";
      out << SP << "constexpr int "<<OpName<< "_incy = 1;\n";
      out << SP << "BLAS::scopy_(&" << OpName << "_N, " << "tensor_" << fNX << ", &" << OpName << "_incx," << "tensor_" << fNY << ", &" << OpName << "_incy);\n\n";
@@ -222,7 +225,7 @@ public:
         << "tensor_" << fNY << ", &" << OpName << "_incy);\n\n";
 
      if(fActivation == EActivationType::RELU){
-         out << SP << "for (int id = 0; id < " << ConvertShapeToLength(fShapeY) << " ; id++){\n";
+         out << SP << "for (int id = 0; id < " << n << " ; id++){\n";
         out << SP << SP << "tensor_" << fNY << "[id] = ((tensor_" << fNY << "[id] > 0 )? tensor_" << fNY << "[id] : 0);\n";
         out << SP << "}\n";
      }
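The BatchNormalization change rides on the Dim-based shape API: fShapeX now holds entries that are either fixed sizes or named parameters, and ConvertDimShapeToLength turns them into the length expression pasted into the generated BLAS calls. A reduced stand-in for that mechanism (this Dim and helper are simplified for illustration, not the actual SOFIE definitions):

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Reduced stand-in for SOFIE's Dim: either a known size or a parameter name.
    struct Dim {
       bool isParam = false;
       std::size_t dim = 0;
       std::string param;
       std::string GetVal() const { return isParam ? param : std::to_string(dim); }
    };

    // Builds the "d0 * d1 * ..." length expression used in the generated code.
    std::string ConvertDimShapeToLength(const std::vector<Dim>& shape) {
       std::string out;
       for (std::size_t i = 0; i < shape.size(); i++) {
          if (i) out += " * ";
          out += shape[i].GetVal();
       }
       return out;
    }

    int main() {
       // dynamic batch size "bs", fixed channels/height/width
       std::vector<Dim> shapeX = {Dim{true, 0, "bs"}, Dim{false, 3, {}},
                                  Dim{false, 4, {}}, Dim{false, 4, {}}};
       std::cout << ConvertDimShapeToLength(shapeX) << "\n"; // prints: bs * 3 * 4 * 4
    }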

tmva/sofie/inc/TMVA/ROperator_Constant.hxx

Lines changed: 70 additions & 26 deletions
@@ -20,6 +20,8 @@ private:
    std::string fNX;
    std::string fNY;
    std::vector<size_t> fShape;
+   std::vector<Dim> fDimShape;
+   std::vector<Dim> fDimOutputShape;
    std::vector<T> fValues;
    std::string fAttrType;
    bool fIsConstantOfShape = false;
@@ -58,28 +60,47 @@ public:
       }
       // get output shape from input values:
       // can work only if input is a constant or initialized tensor (or dynamic one)
-      auto dptr = model.GetInitializedTensorData(fNX);
-      auto input_tensor = static_cast<int64_t *>(dptr.get());
-      auto input_shape = model.GetTensorShape(fNX);
-      if (input_shape.size() > 1 )
-         throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has invalid shape");
-      if (input_tensor != nullptr && !input_shape.empty()) {
-         fShape = std::vector<size_t> (input_shape[0]);
-         for (size_t i = 0; i < fShape.size(); i++)
-            fShape[i] = input_tensor[i];
-      } else
-         fShape = {1}; // scalar case
-
-      length = ConvertShapeToLength(fShape);
-      if (fValues.size() != 1)
-         throw std::runtime_error("TMVA SOFIE ConstantOfShape Op value Tensor has invalid size " + std::to_string(fValues.size()));
-
-      T value = fValues[0];
-      fValues = std::vector<T>(length, value);
+      if (model.IsInitializedTensor(fNX) || model.IsConstantTensor(fNX)) {
+         fIsOutputConstant = true;
+         auto dptr = model.GetInitializedTensorData(fNX);
+         auto input_tensor = static_cast<int64_t *>(dptr.get());
+         auto input_shape = model.GetTensorShape(fNX);
+         if (input_shape.size() > 1 )
+            throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has invalid shape");
+         if (input_tensor != nullptr && !input_shape.empty()) {
+            fShape = std::vector<size_t> (input_shape[0]);
+            for (size_t i = 0; i < fShape.size(); i++)
+               fShape[i] = input_tensor[i];
+         } else
+            fShape = {1}; // scalar case
+
+         length = ConvertShapeToLength(fShape);
+         if (fValues.size() != 1)
+            throw std::runtime_error("TMVA SOFIE ConstantOfShape Op value Tensor has invalid size " + std::to_string(fValues.size()));
+
+         T value = fValues[0];
+         fValues = std::vector<T>(length, value);
+      }
+      else {
+         // case of non constant tensors- we need to do at run time
+         fDimShape = model.GetDimTensorShape(fNX);
+         if (fDimShape.size() > 1 )
+            throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has invalid shape");
+         if (!fDimShape[0].isParam) {
+            fDimOutputShape.resize(fDimShape[0].dim);
+            for (size_t i = 0; i < fDimShape[0].dim; i++) {
+               fDimOutputShape[i] = Dim{ std::string("s_") + fNY + "_" + std::to_string(i)};
+            }
+         }
+         else {
+            throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has not defied shape");
+         }
+      }
 
    } else {
       // case of constant operator
       // in case of standard constant the shape is provided as input
+      fIsOutputConstant = true;
       length = ConvertShapeToLength(fShape);
       if (length != fValues.size())
          throw std::runtime_error("TMVA SOFIE Constant Op has invalid shape : " + ConvertShapeToString(fShape) +
@@ -90,18 +111,41 @@ public:
      // but keep its initialization in the generated code. The values might also be needed in initializing the
      // following operators using as input Constant or ConstantOfShape
      // resize fValues to shape length
-      model.AddConstantTensor(fNY, fShape, fValues);
-      if (model.Verbose()) {
-         std::cout << "adding constant tensor " << fNY << " with shape " << ConvertShapeToString(fShape)
-                   << " and values [";
-         for (auto v : fValues) std::cout << " " << v;
-         std::cout << "]" << std::endl;
+      if (fIsOutputConstant) {
+         model.AddConstantTensor(fNY, fShape, fValues);
+         if (model.Verbose()) {
+            std::cout << "adding constant tensor " << fNY << " with shape " << ConvertShapeToString(fShape)
+                      << " and values [";
+            for (auto v : fValues) std::cout << " " << v;
+            std::cout << "]" << std::endl;
+         }
+      } else {
+         model.AddIntermediateTensor(fNY, ConvertStringToType(TensorType<T>::Name()), fDimOutputShape);
      }
   }
 
-   std::string Generate(std::string /* OpName */) override {
+   std::string Generate(std::string opName) override {
      // no code to generate here. Tensor are defined in Session constructor
-      return "//---------------------------------------\n";
+      if (fIsOutputConstant) {
+         if (fNX.empty())
+            return "// ---- Constant (no-op) \n";
+         else
+            return "// ---- ConstantOfShape (no-op) \n";
+      }
+      // Only ConstantOfShape might require generation code
+      // generate constant tensor according to input
+      std::stringstream out;
+      out << "\n//--------- ConstantOfShape " << opName << "\n";
+      // set shape values
+      for (size_t i = 0; i < fDimOutputShape.size(); i++) {
+         out << SP << "size_t " << fDimOutputShape[i].param << " = " << "tensor_" << fNX << "[" << i << "];\n";
+      }
+      auto length = ConvertDimShapeToLength(fDimOutputShape);
+      // vector is already allocated- fill with values
+      out << SP << "if (" << length << " > fTensor_" << fNY << ".size())\n";
+      out << SP << SP << "fTensor_" << fNY << ".resize(" << length << ");\n";
+      out << SP << "std::fill(fTensor_" << fNY << ".begin(), fTensor_" << fNY << ".end(), " << fValues[0] << ");\n";
+      return out.str();
   }
 };
 
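To make the new run-time ConstantOfShape path concrete, the Generate() body above emits code along the following lines for a node Y whose rank-1 shape input X has length 2 and whose fill value is 0. This is a sketch under those assumptions, with illustrative names (constantOfShape, tensor_X, fTensor_Y), not verbatim SOFIE output:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Approximation of the emitted ConstantOfShape code: read the output shape
    // from the run-time shape tensor, grow the pre-allocated output if needed,
    // and fill it with the constant value.
    void constantOfShape(const std::vector<int64_t>& tensor_X,
                         std::vector<float>& fTensor_Y) {
       // set shape values from the run-time shape tensor
       std::size_t s_Y_0 = tensor_X[0];
       std::size_t s_Y_1 = tensor_X[1];
       // vector is already allocated by the Session; resize only if too small
       if (s_Y_0 * s_Y_1 > fTensor_Y.size())
          fTensor_Y.resize(s_Y_0 * s_Y_1);
       std::fill(fTensor_Y.begin(), fTensor_Y.end(), 0.f);
    }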

tmva/sofie/inc/TMVA/ROperator_Gather.hxx

Lines changed: 2 additions & 0 deletions
@@ -138,6 +138,8 @@ public:
       // Indices of shape q
       size_t q = fShapeIndices.size();
       // Strides
+      std::cout << "shapes of Gather " << ConvertShapeToString(fShapeX) << " " <<
+         ConvertShapeToString(fShapeY) << " " << ConvertShapeToString(fShapeIndices) << std::endl;
       auto stridesX = UTILITY::ComputeStrideFromShape(fShapeX);
       auto stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
       auto stridesIndices = UTILITY::ComputeStrideFromShape(fShapeIndices);
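The Gather change itself is only a debug printout next to the stride computations. For context, UTILITY::ComputeStrideFromShape yields row-major strides; a simplified stand-in (not the SOFIE source) behaves like this:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Simplified stand-in for UTILITY::ComputeStrideFromShape: row-major strides,
    // so strides[i] is the product of all dimensions after i.
    std::vector<std::size_t> ComputeStrideFromShape(const std::vector<std::size_t>& shape) {
       std::vector<std::size_t> strides(shape.size(), 1);
       for (int i = static_cast<int>(shape.size()) - 2; i >= 0; i--)
          strides[i] = strides[i + 1] * shape[i + 1];
       return strides;
    }

    int main() {
       for (std::size_t s : ComputeStrideFromShape({2, 3, 4}))
          std::cout << s << " "; // prints: 12 4 1
       std::cout << "\n";
    }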
