@@ -36,10 +36,10 @@ private:
    std::string convK;
    std::string imcol;

-   std::vector<size_t> fShapeX;
+   std::vector<Dim> fShapeX;
    std::vector<size_t> fShapeW;
    std::vector<size_t> fShapeB;
-   std::vector<size_t> fShapeY;
+   std::vector<Dim> fShapeY;

    std::string fType;

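For readers outside SOFIE: the `Dim` type that replaces `size_t` in these shapes carries either a concrete extent or a symbolic (parametric) one. A minimal sketch, inferred only from the usages in this diff (`isParam`, `dim`, `param`, brace construction from a `size_t` or a string); the real definition lives in the SOFIE headers and may differ:

```cpp
#include <cstddef>
#include <string>

// Illustrative stand-in for SOFIE's Dim, not the actual definition.
struct Dim {
   bool isParam = false;   // true when the extent is a run-time parameter (e.g. batch size)
   std::size_t dim = 0;    // concrete extent, valid when isParam is false
   std::string param;      // symbolic name or expression, valid when isParam is true

   Dim(std::size_t d = 0) : dim(d) {}
   Dim(const std::string & p, std::size_t d = std::size_t(-1)) : isParam(true), dim(d), param(p) {}
};
```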
@@ -93,29 +93,31 @@ public:
    }

    // function returning output shape given input
-   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
+   std::vector<Dim> DoShapeInference(const std::vector<Dim> & input, const std::vector<size_t> & weight) {
       // shape of convolution input has to be (according to ONNX): N x C x H x W
       // Where N : batch size, C : input channels, H : input height, W : input width

-      if (input.size() > 3) {
-         throw
-            std::runtime_error("TMVA SOFIE Conv Op Shape inference need 2 or 3 input tensors");
+      if (input.size() - 2 != fDim) {
+         throw std::runtime_error("TMVA SOFIE Conv Op Shape inference - invalid input");
       }
-      for (size_t i = 0; i < input.size(); i++) {
-         if (input[i].size() - 2 != fDim) {
-            throw
-               std::runtime_error("TMVA SOFIE Conv Op Shape inference - invalid inputs");
-         }
+      if (weight.size() - 2 != fDim) {
+         throw std::runtime_error("TMVA SOFIE Conv Op Shape inference - invalid weights");
+      }
+      if (fAttrGroup == 0 && input[1].isParam)
+         throw std::runtime_error("TMVA SOFIE Conv - param shapes not supported without group attr");
+      if (fAttrKernelShape.empty()) {
+         if (input[2].isParam || (fDim > 1 && input[3].isParam) || (fDim > 2 && input[4].isParam))
+            throw std::runtime_error("TMVA SOFIE Conv - param shapes not supported without kernel attr");
       }

       if (fAttrGroup == 0) {
-         fAttrGroup = input[0][1] / input[1][1];
+         fAttrGroup = input[1].dim / weight[1];
       }

       // kernel shape
-      size_t k1 = ((fAttrKernelShape.empty()) ? input[1][2] : fAttrKernelShape[0]);
-      size_t k2 = (fDim > 1) ? ((fAttrKernelShape.empty()) ? input[1][3] : fAttrKernelShape[1]) : 1;
-      size_t k3 = (fDim > 2) ? ((fAttrKernelShape.empty()) ? input[1][4] : fAttrKernelShape[2]) : 1;
+      size_t k1 = ((fAttrKernelShape.empty()) ? weight[2] : fAttrKernelShape[0]);
+      size_t k2 = (fDim > 1) ? ((fAttrKernelShape.empty()) ? weight[3] : fAttrKernelShape[1]) : 1;
+      size_t k3 = (fDim > 2) ? ((fAttrKernelShape.empty()) ? weight[4] : fAttrKernelShape[2]) : 1;


       size_t i1 = (fDim > 1) ? ((fDim > 2) ? 3 : 2) : 1;
@@ -171,33 +173,62 @@ public:
          fAttrStrides.resize(3, 1);


-      size_t input1 = input[0][2];
-      size_t input2 = (fDim > 1) ? input[0][3] : 1;
-      size_t input3 = (fDim > 2) ? input[0][4] : 1;
+      Dim input1 = input[2];
+      Dim input2 = (fDim > 1) ? input[3] : Dim{1};
+      Dim input3 = (fDim > 2) ? input[4] : Dim{1};

       size_t pad1 = fAttrPads[0] + fAttrPads[i1];
-      size_t output1 = (input1 + pad1 - fAttrKernelShape[0]) / fAttrStrides[0] + 1;

-      size_t batch_size = input[0][0]; // first element in input tensor
-      size_t output_channels = input[1][0]; // first element in weight tensor
+      // function to get output dimension of convolution given input
+
+      auto computeOutput = [&](Dim inputDim, size_t kernel, size_t pad, size_t stride) {
+         if (!inputDim.isParam) {
+            size_t outSize = (inputDim.dim + pad - kernel) / stride + 1;
+            return Dim{outSize};
+         } else {
+            if (stride == 1) {
+               if ((pad - kernel + 1) == 0)
+                  // output is same as input
+                  return inputDim;
+               else {
+                  int64_t v = pad - kernel + 1;
+                  std::string outStr = "(" + inputDim.param + " + " + std::to_string(v) + ")";
+                  return Dim{outStr, static_cast<size_t>(-1)};
+               }
+            } else { // general case (stride not 1)
+               int64_t v = pad - kernel;
+               std::string outStr = "((" + inputDim.param + " + " + std::to_string(v) + ") / "
+                                    + std::to_string(stride) + " + 1)";
+               return Dim{outStr, static_cast<size_t>(-1)};
+            }
+         }
+         throw std::runtime_error("TMVA SOFIE Conv Op - invalid values");
+         return Dim{};
+      };
+
+      Dim output1 = computeOutput(input1, fAttrKernelShape[0], pad1, fAttrStrides[0]);

-      std::vector<std::vector<size_t>> ret({{batch_size, output_channels, output1}});
+      Dim batch_size = input[0]; // first element in input tensor
+      Dim output_channels = Dim{weight[0]}; // first element in weight tensor
+
+      std::vector<Dim> ret({batch_size, output_channels, output1});

       if (fDim == 1)
          return ret;

       size_t pad2 = fAttrPads[1] + fAttrPads[i2];
-      size_t output2 = (input2 + pad2 - fAttrKernelShape[1]) / fAttrStrides[1] + 1;
+      Dim output2 = computeOutput(input2, fAttrKernelShape[1], pad2, fAttrStrides[1]);
+
       // output is N x M x OH x OW
-      ret[0].push_back(output2);
+      ret.push_back(output2);
       if (fDim == 2)
          return ret;

       size_t pad3 = fAttrPads[2] + fAttrPads[i3];
-      size_t output3 = (input3 + pad3 - fAttrKernelShape[2]) / fAttrStrides[2] + 1;
+      Dim output3 = computeOutput(input3, fAttrKernelShape[2], pad3, fAttrStrides[2]);

       // output is N x M x OH x OW x OD
-      ret[0].push_back(output3);
+      ret.push_back(output3);
       return ret;
    }

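The `computeOutput` lambda above encodes the usual convolution output-size rule, out = (in + pad_begin + pad_end - kernel) / stride + 1, either as a plain number (concrete dimension) or as a string that the generated code evaluates at run time (parametric dimension). A small standalone illustration of both branches; names and values are made up:

```cpp
#include <cstddef>
#include <iostream>
#include <string>

int main() {
   // Concrete case: a "same" convolution, (32 + 2 - 3) / 1 + 1 == 32.
   std::size_t in = 32, pad = 2, kernel = 3, stride = 1;
   std::cout << (in + pad - kernel) / stride + 1 << "\n";   // prints 32

   // Parametric case: the same expression emitted as text for the generated code.
   std::string inParam = "input_h";
   long long v = static_cast<long long>(pad) - static_cast<long long>(kernel);
   std::cout << "((" + inParam + " + " + std::to_string(v) + ") / "
                + std::to_string(stride) + " + 1)"
             << "\n";                                        // prints ((input_h + -1) / 1 + 1)
   return 0;
}
```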
@@ -207,7 +238,7 @@ public:
          throw
             std::runtime_error("TMVA SOFIE Conv op Input Tensor " + fNX + " is not found in model");
       }
-      fShapeX = model.GetTensorShape(fNX);
+      fShapeX = model.GetDimTensorShape(fNX);
       if (fShapeX.size() < 3 || fShapeX.size() > 5) {
          std::cout << fNX << " : " << ConvertShapeToString(fShapeX) << std::endl;
          throw
@@ -223,24 +254,25 @@ public:
          std::cout << fNW << " : " << ConvertShapeToString(fShapeW) << std::endl;
          throw std::runtime_error("TMVA SOFIE Conv Op input weight tensor" + fNW + " is not of 3,4 or 5 dimensions");
       }
-      fShapeY = ShapeInference({fShapeX, fShapeW})[0];
+      fShapeY = DoShapeInference(fShapeX, fShapeW);
       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
       if (fNB != "") {
          if (!model.CheckIfTensorAlreadyExist(fNB)) {
             throw
                std::runtime_error("TMVA SOFIE Conv op Input Tensor " + fNB + " is not found in model");
          }
          fShapeB = model.GetTensorShape(fNB);
-         std::vector<size_t> targetShape(fShapeY.begin() + 1, fShapeY.end());
-         bool broadcast_needed = !UTILITY::AreSameShape(fShapeB, targetShape);
+         std::vector<Dim> targetShape(fShapeY.begin() + 1, fShapeY.end());
+         auto shapeDimB = model.GetDimTensorShape(fNB);
+         bool broadcast_needed = !UTILITY::AreSameShape(shapeDimB, targetShape);
          if (broadcast_needed) {
             auto original_data = model.GetInitializedTensorData(fNB);
             // make bias shape equal to Y shape by adding 1
             if (fShapeB.size() < 1)
                throw std::runtime_error("TMVA SOFIE Conv op: Bias Tensor has empty shape");
             // we assume bias tensor dimension is equal to number of filters that is the second dimension in
             // the output tensor
-            if (fShapeB[0] != fShapeY[1])
+            if (!(shapeDimB[0] == fShapeY[1]))
                throw std::runtime_error("TMVA SOFIE Conv op: Bias Tensor has wrong shape: " +
                                         ConvertShapeToString(fShapeB));
             if (fType != "float")
@@ -249,10 +281,11 @@ public:
             if (!fUseSession) {
                std::vector<size_t> shape(fDim + 1, 1);
                shape[0] = fShapeB[0];
+               auto intTargetShape = ConvertShapeToInt(targetShape);
                std::shared_ptr<void> new_data_ptr(
-                  UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(original_data.get()), shape, targetShape),
+                  UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(original_data.get()), shape, intTargetShape),
                   std::default_delete<float[]>());
-               model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), targetShape, new_data_ptr);
+               model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), intTargetShape, new_data_ptr);
                fShapeB = model.GetTensorShape(fNB);
                fNB2 = fNB; // use same name
             }
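For context on the `UnidirectionalBroadcast` call above: the bias of shape {C}, viewed as {C, 1, ..., 1}, is replicated over the spatial output so it can later be added element-wise. A toy 2-D sketch of that idea; this is an illustration, not the SOFIE `UTILITY::UnidirectionalBroadcast` implementation:

```cpp
#include <cstddef>
#include <vector>

// Replicate a {C} bias over a {C, H, W} output block.
std::vector<float> BroadcastBias(const std::vector<float> & bias, std::size_t H, std::size_t W) {
   std::vector<float> out(bias.size() * H * W);
   for (std::size_t c = 0; c < bias.size(); ++c)
      for (std::size_t i = 0; i < H * W; ++i)
         out[c * H * W + i] = bias[c];   // every spatial position of channel c gets bias[c]
   return out;
}
```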
@@ -264,22 +297,27 @@ public:
             }
          }
       }
-
-      size_t outputChannelSize = fShapeY[2]; // size/channel = D * H * W
+      // output channel size can be parametric
+      std::vector<Dim> outputDims = std::vector<Dim>(fShapeY.begin() + 2, fShapeY.end());
+      auto outputChannelSize = ConvertDimShapeToLength(outputDims); // size/channel = D * H * W
       size_t kernelSize = fAttrKernelShape[0];
       for (size_t i = 1; i < fDim; i++) {
-         outputChannelSize *= fShapeY[2 + i];
         kernelSize *= fAttrKernelShape[i];
       }

       std::vector<size_t> shape1 = {fShapeW[0], fShapeW[1], kernelSize};
-      std::vector<size_t> shape2 = {fShapeW[1], kernelSize, outputChannelSize};
+      std::vector<Dim> shape2 = {Dim{fShapeW[1]}, Dim{kernelSize}, Dim{outputChannelSize}};
       model.AddIntermediateTensor(fNX + "_f", ConvertStringToType(fType), shape1);
       model.AddIntermediateTensor(fNX + "_xcol", ConvertStringToType(fType), shape2);
       convK = fNX + "_f";
       imcol = fNX + "_xcol";
       fOutputTensorNames.emplace_back(convK);
       fOutputTensorNames.emplace_back(imcol);
+
+      if (model.Verbose()) {
+         std::cout << "Conv - " << fDim << " " << fNX << " : " << ConvertShapeToString(fShapeX)
+                   << " --> " << fNY << " : " << ConvertShapeToString(fShapeY) << std::endl;
+      }
    }

    std::string GenerateInitCode() override {
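The `_f` and `_xcol` tensors allocated above hold the flattened kernels and the im2col column matrix, whose last dimension is now the possibly parametric `outputChannelSize`. A toy 1-D im2col, just to illustrate what the column buffer contains; the real `UTILITY::Im2col` also handles 2-D/3-D, padding, stride and dilation:

```cpp
#include <cstddef>
#include <vector>

// Each output position gets one column holding the k input values the kernel
// sees there (no padding, stride 1). Layout: row p of the column matrix.
std::vector<float> Im2col1D(const std::vector<float> & x, std::size_t k) {
   std::size_t outW = x.size() - k + 1;
   std::vector<float> col(k * outW);
   for (std::size_t p = 0; p < k; ++p)
      for (std::size_t o = 0; o < outW; ++o)
         col[p * outW + o] = x[o + p];
   return col;
}
```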
@@ -289,11 +327,11 @@ public:
          // include a separate scope to avoid defining unique operator temp variables
          std::vector<size_t> shape(fDim + 1, 1);
          shape[0] = fShapeB[0];
-         std::vector<size_t> targetShape(fShapeY.begin() + 1, fShapeY.end());
+         std::vector<Dim> targetShape(fShapeY.begin() + 1, fShapeY.end());
          out << SP << "{\n";
          out << SP << SP << "float * data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_"
              << fNB << ", " << ConvertShapeToString(shape) << ", " << ConvertShapeToString(fShapeY) << ");\n";
-         out << SP << SP << "std::copy(data, data + " << ConvertShapeToLength(targetShape) << ", tensor_" << fNB2 << ");\n";
+         out << SP << SP << "std::copy(data, data + " << ConvertDimShapeToLength(targetShape) << ", tensor_" << fNB2 << ");\n";
          out << SP << SP << "delete[] data;\n";
          out << SP << "}\n";
       }
@@ -309,16 +347,22 @@ public:
       }

       std::stringstream out;
-      size_t bsize = fShapeX[0];
+      auto bsize = fShapeX[0];
       size_t kDepth = (fDim > 2) ? fShapeW[2] : 1; // kernel depth
       size_t kHeight = (fDim > 1) ? fShapeW[fDim] : 1; // kernel height
       size_t kWidth = fShapeW[fDim + 1]; // kernel width
-      size_t iDepth = (fDim > 2) ? fShapeX[2] : 1; // input depth
-      size_t iHeight = (fDim > 1) ? fShapeX[fDim] : 1; // input height
-      size_t iWidth = fShapeX[fDim + 1]; // input width
-      size_t oDepth = (fDim > 2) ? fShapeY[2] : 1; // output depth
-      size_t oHeight = (fDim > 1) ? fShapeY[fDim] : 1; // ouput height
-      size_t oWidth = fShapeY[fDim + 1]; // output width
+      auto iDepth = (fDim > 2) ? fShapeX[2] : Dim{1}; // input depth
+      auto iHeight = (fDim > 1) ? fShapeX[fDim] : Dim{1}; // input height
+      auto iWidth = fShapeX[fDim + 1]; // input width
+      auto oDepth = (fDim > 2) ? fShapeY[2] : Dim{1}; // output depth
+      auto oHeight = (fDim > 1) ? fShapeY[fDim] : Dim{1}; // output height
+      auto oWidth = fShapeY[fDim + 1]; // output width
+      // total output size for a channel
+      auto outputChannelStride = ConvertDimShapeToLength(std::vector<Dim>{oDepth, oHeight, oWidth}); // size of channel = D * H * W
+      auto outputBatchStride = ConvertDimShapeToLength(std::vector<Dim>{fShapeY[1], oDepth, oHeight, oWidth}); // size of C * D * H * W
+      // input size
+      auto inputChannelStride = ConvertDimShapeToLength(std::vector<Dim>{iDepth, iHeight, iWidth});
+      auto inputBatchStride = ConvertDimShapeToLength(std::vector<Dim>{fShapeX[1], iDepth, iHeight, iWidth}); // size of C * D * H * W

       out << "\n//---- operator Conv " << OpName << "\n";

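The new `*Stride` variables come from `ConvertDimShapeToLength`, which, judging from how the results are streamed into the generated code, is assumed to yield either a plain number or a run-time expression string. A hypothetical helper with that behaviour, with a local copy of the `Dim` sketch so the snippet is self-contained; this is not the SOFIE implementation:

```cpp
#include <cstddef>
#include <string>
#include <vector>

struct Dim {
   bool isParam = false;
   std::size_t dim = 0;
   std::string param;
};

// Multiply all extents: a single number for a fully concrete shape,
// an expression such as "input_h * 50176" when some extents are parametric.
std::string DimShapeLength(const std::vector<Dim> & shape) {
   std::string expr;
   std::size_t constPart = 1;
   for (const auto & d : shape) {
      if (d.isParam)
         expr += (expr.empty() ? "" : " * ") + d.param;
      else
         constPart *= d.dim;                         // fold all concrete extents together
   }
   if (expr.empty())
      return std::to_string(constPart);              // fully concrete shape: a single number
   return expr + " * " + std::to_string(constPart);  // mixed: evaluated by the generated code
}
```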
@@ -366,9 +410,9 @@ public:
       // out << SP << "char " << OpName << "_transA = 'T';\n";
       out << SP << "char " << OpName << "_transA = 'N';\n";
       out << SP << "char " << OpName << "_transB = 'N';\n";
-      out << SP << "int " << OpName << "_m = " << oHeight * oWidth * oDepth << ";\n"; // output h*w
+      out << SP << "int " << OpName << "_m = " << outputChannelStride << ";\n"; // output h*w
       assert(fShapeY[1] == fShapeW[0]);
-      assert(fShapeW[1] == fShapeX[1] / fAttrGroup);
+      // assert(fShapeW[1] == fShapeX[1] / fAttrGroup);
       out << SP << "int " << OpName << "_n = " << fShapeW[0] << ";\n"; // output channels
       out << SP << "int " << OpName << "_k = " << fShapeW[1] * fAttrKernelShape[0] * fAttrKernelShape[1] * fAttrKernelShape[2] << ";\n";
       out << SP << "float " << OpName << "_alpha = 1.0;\n";
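For reference, the `_m`, `_n`, `_k` values set up above are the dimensions of the im2col-based GEMM: output[n, m] = sum over k of kernel[n, k] * xcol[k, m], with m the output spatial size per channel, n the number of output channels and k the per-group input channels times the kernel volume. A plain-loop equivalent of that product; the generated code uses a BLAS sgemm call instead:

```cpp
#include <cstddef>
#include <vector>

// kernel is n x k (row-major), xcol is k x m, out is n x m.
void ConvAsGemm(const std::vector<float> & kernel, const std::vector<float> & xcol,
                std::vector<float> & out, std::size_t n, std::size_t m, std::size_t k) {
   for (std::size_t i = 0; i < n; ++i)          // output channel
      for (std::size_t j = 0; j < m; ++j) {     // output spatial position
         float sum = 0.f;
         for (std::size_t p = 0; p < k; ++p)    // input channel x kernel element
            sum += kernel[i * k + p] * xcol[p * m + j];
         out[i * m + j] = sum;
      }
}
```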
@@ -409,10 +453,10 @@ public:
             fAttrPads[2] = (fAttrPads[2] + fAttrPads[5]) / 2;
          }
       }
-      out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
+      out << SP << SP << "size_t out_offset = n * " << outputBatchStride << ";\n";

       if (fAttrGroup == 1) {
-         out << SP << SP << "size_t x_offset = n * " << fShapeX[1] * iHeight * iWidth << ";\n";
+         out << SP << SP << "size_t x_offset = n * " << inputBatchStride << ";\n";
          // when using im2col - resulting matrix is transposed, the dimension is (input_c * filter_h * filter_y, output_h *
          // output_w)
          if (fDim < 3) {
@@ -456,10 +500,10 @@ public:
          // group)
          // out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
          out << SP << SP << "for (size_t g = 0; g < " << fAttrGroup << "; g++) {\n";
-         out << SP << SP << "size_t x_offset = n * " << fShapeX[1] * iDepth * iHeight * iWidth << " + g * "
-             << fShapeW[1] * iDepth * iHeight * iWidth << ";\n";
-         out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << " + g * "
-             << fShapeW[0] * oDepth * oHeight * oWidth / fAttrGroup << ";\n";
+         out << SP << SP << "size_t x_offset = n * " << inputBatchStride << " + g * "
+             << fShapeW[1] << " * " << inputChannelStride << ";\n";
+         out << SP << SP << "size_t out_offset = n * " << outputBatchStride << " + g * "
+             << fShapeW[0] << " * (" << outputChannelStride << ") / " << fAttrGroup << ";\n";

          if (fDim < 3) {
             out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col<float>(tensor_" << fNX
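The grouped-convolution offsets emitted above can be read as: skip n full batches, then g blocks of C/group input channels (for x), or one group's share of the output channels (for the output). A small helper mirroring that arithmetic with explicit element strides, for illustration only:

```cpp
#include <cstddef>
#include <utility>

// Returns {x_offset, out_offset} for batch entry n and group g, using the same
// formulas as the generated strings. channelsPerGroup corresponds to fShapeW[1]
// and outChannels to fShapeW[0]; all strides are element counts.
std::pair<std::size_t, std::size_t>
GroupOffsets(std::size_t n, std::size_t g, std::size_t groups,
             std::size_t channelsPerGroup, std::size_t outChannels,
             std::size_t inputChannelStride, std::size_t outputChannelStride,
             std::size_t inputBatchStride, std::size_t outputBatchStride) {
   std::size_t x_offset   = n * inputBatchStride  + g * channelsPerGroup * inputChannelStride;
   std::size_t out_offset = n * outputBatchStride + g * (outChannels * outputChannelStride) / groups;
   return {x_offset, out_offset};
}
```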
@@ -508,7 +552,7 @@ public:
       }

       if (fNB2 != "") {
-         out << SP << "int " << OpName << "_size = " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
+         out << SP << "int " << OpName << "_size = " << outputBatchStride << ";\n";
          out << SP << "float " << OpName << "_gamma = 1.0;\n";
          out << SP << "int " << OpName << "_incx = 1;\n";
          out << SP << "int " << OpName << "_incy = 1;\n";
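The `_size`, `_gamma`, `_incx`, `_incy` variables feed the BLAS axpy call that adds the already-broadcast bias to one batch entry of the output, y = gamma * b + y with gamma = 1. A plain-loop sketch of the same operation:

```cpp
#include <cstddef>
#include <vector>

// biasBroadcast is assumed to already have the full C * D * H * W layout of one batch entry.
void AddBias(const std::vector<float> & biasBroadcast, std::vector<float> & y) {
   for (std::size_t i = 0; i < y.size(); ++i)
      y[i] += biasBroadcast[i];
}
```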