Skip to content

Commit f88344e

Browse files
committed
[tmva][sofie] Extend dynamic shape support to more operators
Add support now for Conv, Gather, Comparison, Slice and ScatterElements
1 parent e308acb commit f88344e

File tree

6 files changed

+368
-206
lines changed

6 files changed

+368
-206
lines changed

tmva/sofie/inc/TMVA/ROperator_Comparision.hxx

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ private:
6262
std::string fNY;
6363
std::vector<size_t> fShapeX1;
6464
std::vector<size_t> fShapeX2;
65+
std::vector<Dim> fDimShapeX1;
66+
std::vector<Dim> fDimShapeX2;
6567
std::vector<size_t> fShapeY;
6668
std::string fNBroadcastedX1;
6769
std::string fNBroadcastedX2;
@@ -75,7 +77,7 @@ public:
7577
ROperator_Comparision(const std::string & nameX1, const std::string & nameX2, const std::string & nameY):
7678
fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){
7779
fInputTensorNames = { fNX1, fNX2 };
78-
80+
7981
// output will be a boolean vector so should not be considered for memory optimized pool
8082
fOutputTensorNames = { fNY };
8183
}
@@ -99,8 +101,18 @@ public:
99101
if (!model.CheckIfTensorAlreadyExist(fNX2)) {
100102
throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX2 + "is not found in model");
101103
}
102-
fShapeX1 = model.GetTensorShape(fNX1);
103-
fShapeX2 = model.GetTensorShape(fNX2);
104+
if (model.IsDynamicTensor(fNX1))
105+
fDimShapeX1 = model.GetDynamicTensorShape(fNX1);
106+
else {
107+
fShapeX1 = model.GetTensorShape(fNX1);
108+
fDimShapeX1 = ConvertShapeToDim(fShapeX1);
109+
}
110+
if (model.IsDynamicTensor(fNX2))
111+
fDimShapeX2 = model.GetDynamicTensorShape(fNX2);
112+
else {
113+
fShapeX2 = model.GetTensorShape(fNX2);
114+
fDimShapeX2 = ConvertShapeToDim(fShapeX2);
115+
}
104116
fTensorType1 = model.GetTensorType(fNX1);
105117
fTensorType2 = model.GetTensorType(fNX2);
106118
bool broadcast = !UTILITY::AreSameShape(fShapeX1, fShapeX2);

tmva/sofie/inc/TMVA/ROperator_Conv.hxx

Lines changed: 99 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ private:
3636
std::string convK;
3737
std::string imcol;
3838

39-
std::vector<size_t> fShapeX;
39+
std::vector<Dim> fShapeX;
4040
std::vector<size_t> fShapeW;
4141
std::vector<size_t> fShapeB;
42-
std::vector<size_t> fShapeY;
42+
std::vector<Dim> fShapeY;
4343

4444
std::string fType;
4545

@@ -93,29 +93,31 @@ public:
9393
}
9494

9595
// function returning output shape given input
96-
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
96+
std::vector<Dim> DoShapeInference(const std::vector<Dim> & input, const std::vector<size_t> & weight) {
9797
// shape of convolution input has to be (according to ONNX): N x C x H x W
9898
// Where N : batch size, C : input channels, H : input height, W : input width
9999

100-
if (input.size() > 3 ) {
101-
throw
102-
std::runtime_error("TMVA SOFIE Conv Op Shape inference need 2 or 3 input tensors");
100+
if (input.size() -2 != fDim) {
101+
throw std::runtime_error("TMVA SOFIE Conv Op Shape inference - invalid input ");
103102
}
104-
for(size_t i = 0; i < input.size(); i++) {
105-
if (input[i].size() -2 != fDim) {
106-
throw
107-
std::runtime_error("TMVA SOFIE Conv Op Shape inference - invalid inputs ");
108-
}
103+
if (weight.size() -2 != fDim) {
104+
throw std::runtime_error("TMVA SOFIE Conv Op Shape inference - invalid weights ");
105+
}
106+
if (fAttrGroup == 0 && input[1].isParam)
107+
throw std::runtime_error("TMVA SOFIE Conv - param shapes not supported without group attr");
108+
if (fAttrKernelShape.empty()) {
109+
if (input[2].isParam || (fDim > 1 && input[3].isParam) || (fDim > 2 && input[4].isParam))
110+
throw std::runtime_error("TMVA SOFIE Conv - param shapes not supported without kernel attr");
109111
}
110112

111113
if (fAttrGroup == 0) {
112-
fAttrGroup = input[0][1] / input[1][1];
114+
fAttrGroup = input[1].dim / weight[1];
113115
}
114116

115117
// kernel shape
116-
size_t k1 = ((fAttrKernelShape.empty())? input[1][2] : fAttrKernelShape[0]);
117-
size_t k2 = (fDim > 1) ? ((fAttrKernelShape.empty()) ? input[1][3] : fAttrKernelShape[1]) : 1;
118-
size_t k3 = (fDim > 2) ? ((fAttrKernelShape.empty()) ? input[1][4] : fAttrKernelShape[2]) : 1;
118+
size_t k1 = ((fAttrKernelShape.empty())? weight[2] : fAttrKernelShape[0]);
119+
size_t k2 = (fDim > 1) ? ((fAttrKernelShape.empty()) ? weight[3] : fAttrKernelShape[1]) : 1;
120+
size_t k3 = (fDim > 2) ? ((fAttrKernelShape.empty()) ? weight[4] : fAttrKernelShape[2]) : 1;
119121

120122

121123
size_t i1 = (fDim > 1) ? ((fDim > 2) ? 3 : 2) : 1;
@@ -171,33 +173,62 @@ public:
171173
fAttrStrides.resize(3, 1);
172174

173175

174-
size_t input1 = input[0][2];
175-
size_t input2 = (fDim > 1) ? input[0][3] : 1;
176-
size_t input3 = (fDim > 2) ? input[0][4] : 1;
176+
Dim input1 = input[2];
177+
Dim input2 = (fDim > 1) ? input[3] : Dim{1};
178+
Dim input3 = (fDim > 2) ? input[4] : Dim{1};
177179

178180
size_t pad1 = fAttrPads[0] + fAttrPads[i1];
179-
size_t output1 = (input1 + pad1 - fAttrKernelShape[0]) / fAttrStrides[0] + 1;
180181

181-
size_t batch_size = input[0][0]; // first element in input tensor
182-
size_t output_channels = input[1][0]; // first element in weight tensor
182+
// function to get output dimension of convolution given input
183+
184+
auto computeOutput = [&](Dim inputDim, size_t kernel, size_t pad, size_t stride) {
185+
if (!inputDim.isParam) {
186+
size_t outSize = (inputDim.dim + pad - kernel) / stride + 1;
187+
return Dim{outSize};
188+
} else {
189+
if (stride == 1){
190+
if ((pad - kernel + 1) == 0 )
191+
// output is same as input
192+
return inputDim;
193+
else {
194+
int64_t v = pad - kernel + 1;
195+
std::string outStr = "(" + inputDim.param + "+" + std::to_string(v) + ")";
196+
return Dim{ outStr, static_cast<size_t>(-1)};
197+
}
198+
} else { // general case (stride not 1)
199+
int64_t v = pad - kernel;
200+
std::string outStr = "((" + inputDim.param + "+" + std::to_string(v) + ")/"
201+
+ std::to_string(stride) + "1)";
202+
return Dim{ outStr, static_cast<size_t>(-1)};
203+
}
204+
}
205+
std::runtime_error("TMVA SOFIE Conv Op - invalid values");
206+
return Dim{};
207+
};
208+
209+
Dim output1 = computeOutput(input1, fAttrKernelShape[0], pad1, fAttrStrides[0]);
183210

184-
std::vector<std::vector<size_t>> ret({{ batch_size, output_channels, output1 }});
211+
Dim batch_size = input[0]; // first element in input tensor
212+
Dim output_channels = Dim{weight[0]}; // first element in weight tensor
213+
214+
std::vector<Dim> ret({ batch_size, output_channels, output1 });
185215

186216
if (fDim == 1)
187217
return ret;
188218

189219
size_t pad2 = fAttrPads[1] + fAttrPads[i2];
190-
size_t output2 = (input2 + pad2 - fAttrKernelShape[1]) / fAttrStrides[1] + 1;
220+
Dim output2 = computeOutput(input2, fAttrKernelShape[1], pad2, fAttrStrides[1]);
221+
191222
// output is N x M x OH x OW
192-
ret[0].push_back(output2);
223+
ret.push_back(output2);
193224
if (fDim == 2)
194225
return ret;
195226

196227
size_t pad3 = fAttrPads[2] + fAttrPads[i3];
197-
size_t output3 = (input3 + pad3 - fAttrKernelShape[2] ) / fAttrStrides[2] + 1;
228+
Dim output3 = computeOutput(input3, fAttrKernelShape[2], pad3, fAttrStrides[2]);
198229

199230
// output is N x M x OH x OW x OD
200-
ret[0].push_back(output3);
231+
ret.push_back(output3);
201232
return ret;
202233
}
203234

@@ -207,7 +238,7 @@ public:
207238
throw
208239
std::runtime_error("TMVA SOFIE Conv op Input Tensor " + fNX + " is not found in model");
209240
}
210-
fShapeX = model.GetTensorShape(fNX);
241+
fShapeX = model.GetDimTensorShape(fNX);
211242
if (fShapeX.size() < 3 || fShapeX.size() > 5) {
212243
std::cout << fNX << " : " << ConvertShapeToString(fShapeX) << std::endl;
213244
throw
@@ -223,24 +254,25 @@ public:
223254
std::cout << fNW << " : " << ConvertShapeToString(fShapeW) << std::endl;
224255
throw std::runtime_error("TMVA SOFIE Conv Op input weight tensor" + fNW + " is not of 3,4 or 5 dimensions");
225256
}
226-
fShapeY = ShapeInference({fShapeX, fShapeW})[0];
257+
fShapeY = DoShapeInference(fShapeX, fShapeW);
227258
model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
228259
if (fNB != "") {
229260
if (!model.CheckIfTensorAlreadyExist(fNB)) {
230261
throw
231262
std::runtime_error("TMVA SOFIE Conv op Input Tensor " + fNB + " is not found in model");
232263
}
233264
fShapeB = model.GetTensorShape(fNB);
234-
std::vector<size_t> targetShape(fShapeY.begin() + 1, fShapeY.end());
235-
bool broadcast_needed = !UTILITY::AreSameShape(fShapeB, targetShape);
265+
std::vector<Dim> targetShape(fShapeY.begin() + 1, fShapeY.end());
266+
auto shapeDimB = model.GetDimTensorShape(fNB);
267+
bool broadcast_needed = !UTILITY::AreSameShape(shapeDimB, targetShape);
236268
if (broadcast_needed) {
237269
auto original_data = model.GetInitializedTensorData(fNB);
238270
// make bias shape equal to Y shape by adding 1
239271
if (fShapeB.size() < 1)
240272
throw std::runtime_error("TMVA SOFIE Conv op: Bias Tensor has empty shape");
241273
// we assume bias tensor dimension is equal to number of filters that is the second dimension in
242274
// the output tensor
243-
if (fShapeB[0] != fShapeY[1])
275+
if (!(shapeDimB[0] == fShapeY[1]))
244276
throw std::runtime_error("TMVA SOFIE Conv op: Bias Tensor has wrong shape: " +
245277
ConvertShapeToString(fShapeB));
246278
if (fType != "float")
@@ -249,10 +281,11 @@ public:
249281
if (!fUseSession) {
250282
std::vector<size_t> shape(fDim + 1, 1);
251283
shape[0] = fShapeB[0];
284+
auto intTargetShape = ConvertShapeToInt(targetShape);
252285
std::shared_ptr<void> new_data_ptr(
253-
UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(original_data.get()), shape, targetShape),
286+
UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(original_data.get()), shape, intTargetShape),
254287
std::default_delete<float[]>());
255-
model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), targetShape, new_data_ptr);
288+
model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), intTargetShape, new_data_ptr);
256289
fShapeB = model.GetTensorShape(fNB);
257290
fNB2 = fNB; // use same name
258291
}
@@ -264,22 +297,27 @@ public:
264297
}
265298
}
266299
}
267-
268-
size_t outputChannelSize = fShapeY[2]; // size/channel = D * H * W
300+
// output channel size can be parametric
301+
std::vector<Dim> outputDims = std::vector<Dim>(fShapeY.begin()+2, fShapeY.end());
302+
auto outputChannelSize = ConvertDimShapeToLength(outputDims); // size/channel = D * H * W
269303
size_t kernelSize = fAttrKernelShape[0];
270304
for (size_t i = 1; i < fDim; i++) {
271-
outputChannelSize *= fShapeY[2 + i];
272305
kernelSize *= fAttrKernelShape[i];
273306
}
274307

275308
std::vector<size_t> shape1 = {fShapeW[0], fShapeW[1], kernelSize};
276-
std::vector<size_t> shape2 = {fShapeW[1], kernelSize, outputChannelSize};
309+
std::vector<Dim> shape2 = {Dim{fShapeW[1]}, Dim{kernelSize}, Dim{outputChannelSize}};
277310
model.AddIntermediateTensor(fNX +"_f", ConvertStringToType(fType), shape1 );
278311
model.AddIntermediateTensor(fNX +"_xcol", ConvertStringToType(fType), shape2 );
279312
convK = fNX +"_f";
280313
imcol = fNX +"_xcol";
281314
fOutputTensorNames.emplace_back(convK);
282315
fOutputTensorNames.emplace_back(imcol);
316+
317+
if (model.Verbose()) {
318+
std::cout << "Conv - " << fDim << " " << fNX << " : " << ConvertShapeToString(fShapeX)
319+
<< " --> " << fNY << " : " << ConvertShapeToString(fShapeY) << std::endl;
320+
}
283321
}
284322

285323
std::string GenerateInitCode() override {
@@ -289,11 +327,11 @@ public:
289327
// include a separate scope to avoid defining unique operator temp variables
290328
std::vector<size_t> shape(fDim + 1, 1);
291329
shape[0] = fShapeB[0];
292-
std::vector<size_t> targetShape(fShapeY.begin() + 1, fShapeY.end());
330+
std::vector<Dim> targetShape(fShapeY.begin() + 1, fShapeY.end());
293331
out << SP << "{\n";
294332
out << SP << SP << "float * data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_"
295333
<< fNB << ", " << ConvertShapeToString(shape) << ", " << ConvertShapeToString(fShapeY) << ");\n";
296-
out << SP << SP << "std::copy(data, data + " << ConvertShapeToLength(targetShape) << ", tensor_" << fNB2 << ");\n";
334+
out << SP << SP << "std::copy(data, data + " << ConvertDimShapeToLength(targetShape) << ", tensor_" << fNB2 << ");\n";
297335
out << SP << SP << "delete[] data;\n";
298336
out << SP << "}\n";
299337
}
@@ -309,16 +347,22 @@ public:
309347
}
310348

311349
std::stringstream out;
312-
size_t bsize = fShapeX[0];
350+
auto bsize = fShapeX[0];
313351
size_t kDepth = (fDim > 2) ? fShapeW[2] : 1; // kernel depth
314352
size_t kHeight = (fDim > 1) ? fShapeW[fDim] : 1; // kernel height
315353
size_t kWidth = fShapeW[fDim+1]; // kernel width
316-
size_t iDepth = (fDim > 2) ? fShapeX[2] : 1; // input depth
317-
size_t iHeight = (fDim > 1) ? fShapeX[fDim] : 1; // input height
318-
size_t iWidth = fShapeX[fDim+1]; // input width
319-
size_t oDepth = (fDim > 2) ? fShapeY[2] : 1; // output depth
320-
size_t oHeight = (fDim > 1) ? fShapeY[fDim] : 1; // ouput height
321-
size_t oWidth = fShapeY[fDim+1]; // output width
354+
auto iDepth = (fDim > 2) ? fShapeX[2] : Dim{1}; // input depth
355+
auto iHeight = (fDim > 1) ? fShapeX[fDim] : Dim{1}; // input height
356+
auto iWidth = fShapeX[fDim+1]; // input width
357+
auto oDepth = (fDim > 2) ? fShapeY[2] : Dim{1}; // output depth
358+
auto oHeight = (fDim > 1) ? fShapeY[fDim] : Dim{1}; // output height
359+
auto oWidth = fShapeY[fDim+1]; // output width
360+
// total output size for a channel
361+
auto outputChannelStride = ConvertDimShapeToLength(std::vector<Dim>{oDepth, oHeight, oWidth}); // size of channel = D * H * W
362+
auto outputBatchStride = ConvertDimShapeToLength(std::vector<Dim>{fShapeY[1] , oDepth, oHeight, oWidth}); // size of C * D * H * W
363+
// input size
364+
auto inputChannelStride = ConvertDimShapeToLength(std::vector<Dim>{iDepth, iHeight, iWidth});
365+
auto inputBatchStride = ConvertDimShapeToLength(std::vector<Dim>{fShapeX[1] , iDepth, iHeight, iWidth}); // size of C * D * H * W
322366

323367
out << "\n//---- operator Conv " << OpName << "\n";
324368

@@ -366,9 +410,9 @@ public:
366410
//out << SP << "char " << OpName << "_transA = 'T';\n";
367411
out << SP << "char " << OpName << "_transA = 'N';\n";
368412
out << SP << "char " << OpName << "_transB = 'N';\n";
369-
out << SP << "int " << OpName << "_m = " << oHeight * oWidth * oDepth << ";\n"; // output h*w
413+
out << SP << "int " << OpName << "_m = " << outputChannelStride << ";\n"; // output h*w
370414
assert(fShapeY[1] == fShapeW[0]);
371-
assert(fShapeW[1] == fShapeX[1] / fAttrGroup);
415+
//assert(fShapeW[1] == fShapeX[1] / fAttrGroup);
372416
out << SP << "int " << OpName << "_n = " << fShapeW[0] << ";\n"; // output channels
373417
out << SP << "int " << OpName << "_k = " << fShapeW[1] * fAttrKernelShape[0] * fAttrKernelShape[1] * fAttrKernelShape[2] << ";\n";
374418
out << SP << "float " << OpName << "_alpha = 1.0;\n";
@@ -409,10 +453,10 @@ public:
409453
fAttrPads[2] = (fAttrPads[2] + fAttrPads[5]) / 2;
410454
}
411455
}
412-
out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
456+
out << SP << SP << "size_t out_offset = n * " << outputBatchStride << ";\n";
413457

414458
if (fAttrGroup == 1) {
415-
out << SP << SP << "size_t x_offset = n * " << fShapeX[1] * iHeight * iWidth << ";\n";
459+
out << SP << SP << "size_t x_offset = n * " << inputBatchStride << ";\n";
416460
// when using im2col - resulting matrix is transposed, the dimension is (input_c * filter_h * filter_y, output_h *
417461
// output_w)
418462
if (fDim < 3) {
@@ -456,10 +500,10 @@ public:
456500
// group)
457501
// out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
458502
out << SP << SP << "for (size_t g = 0; g < " << fAttrGroup << "; g++) {\n";
459-
out << SP << SP << "size_t x_offset = n * " << fShapeX[1] * iDepth * iHeight * iWidth << " + g * "
460-
<< fShapeW[1] * iDepth * iHeight * iWidth << ";\n ";
461-
out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << " + g * "
462-
<< fShapeW[0] * oDepth * oHeight * oWidth / fAttrGroup << ";\n ";
503+
out << SP << SP << "size_t x_offset = n * " << inputBatchStride << " + g * "
504+
<< fShapeW[1] << " * " << inputChannelStride << ";\n ";
505+
out << SP << SP << "size_t out_offset = n * " << outputBatchStride << " + g * "
506+
<< fShapeW[0] << " * (" << outputChannelStride << ") / " << fAttrGroup << ";\n ";
463507

464508
if (fDim < 3) {
465509
out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col<float>(tensor_" << fNX
@@ -508,7 +552,7 @@ public:
508552
}
509553

510554
if (fNB2 != "") {
511-
out << SP << "int " << OpName << "_size = " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
555+
out << SP << "int " << OpName << "_size = " << outputBatchStride << ";\n";
512556
out << SP << "float " << OpName << "_gamma = 1.0;\n";
513557
out << SP << "int " << OpName << "_incx = 1;\n";
514558
out << SP << "int " << OpName << "_incy = 1;\n";

0 commit comments

Comments
 (0)