@@ -309,55 +309,73 @@ public:
309
309
// we allocate here output vector
310
310
out << SP << SP << " if (" << fDimShapeA [i] << " != " << fDimShapeB [i] << " && (" << fDimShapeA [i]
311
311
<< " != 1 || " << fDimShapeB [i] << " != 1))\n " ;
312
- out << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast shapes in operator " << opName
312
+ out << SP << SP << SP << " throw std::runtime_error(\" SOFIE - Cannot broadcast shapes in operator " << opName
313
313
<< " \" );\n " ;
314
314
}
315
315
}
316
+ out << SP << " }\n " ;
316
317
}
317
318
318
- auto stridesA = UTILITY::ComputeStrideFromShape (fShapeA );
319
- auto stridesB = UTILITY::ComputeStrideFromShape (fShapeB );
320
- auto stridesY = UTILITY::ComputeStrideFromShape (fShapeY );
319
+ auto stridesA = UTILITY::ComputeStrideFromShape (fDimShapeA );
320
+ auto stridesB = UTILITY::ComputeStrideFromShape (fDimShapeB );
321
+ auto stridesY = UTILITY::ComputeStrideFromShape (fDimShapeY );
321
322
322
323
std::string compute_idx_A, compute_idx_B, compute_idx_Y;
323
- if (std::all_of (fShapeA .begin (), fShapeA .end (), [](size_t x ) { return x == 1 ; })) {
324
+ if (std::all_of (fDimShapeA .begin (), fDimShapeA .end (), [](Dim d ) { return d. dim == 1 || d. GetVal () == " 1 " ; })) {
324
325
compute_idx_A = " 0" ;
325
326
} else {
326
- for (size_t i = 0 ; i < fShapeA .size (); ++i) {
327
- if (fShapeA [i] == 1 )
327
+ for (size_t i = 0 ; i < fDimShapeA .size (); ++i) {
328
+ if (fDimShapeA [i]. dim == 1 || fDimShapeA [i]. GetVal () == " 1 " )
328
329
continue ;
329
- compute_idx_A +=
330
- " idx_" + std::to_string (i + (fShapeY .size () - fShapeA .size ())) + " * " + stridesA[i] + " +" ;
330
+ compute_idx_A += " idx_" + std::to_string (i + (fDimShapeY .size () - fDimShapeA .size ()));
331
+ if (stridesA[i].GetVal () != " 1" )
332
+ compute_idx_A += " * " + stridesA[i].GetVal ();
333
+ compute_idx_A += " + " ;
331
334
}
332
- compute_idx_A.pop_back ();
335
+ // remove last 3 character " + "
336
+ for (int j = 0 ; j < 3 ; j++)
337
+ compute_idx_A.pop_back ();
333
338
}
334
- if (std::all_of (fShapeB .begin (), fShapeB .end (), [](size_t x ) { return x == 1 ; })) {
339
+ if (std::all_of (fDimShapeB .begin (), fDimShapeB .end (), [](Dim d ) { return d. dim == 1 || d. GetVal () == " 1 " ; })) {
335
340
compute_idx_B = " 0" ;
336
341
} else {
337
- for (size_t i = 0 ; i < fShapeB .size (); ++i) {
338
- if (fShapeB [i] == 1 )
342
+ for (size_t i = 0 ; i < fDimShapeB .size (); ++i) {
343
+ if (fDimShapeB [i]. dim == 1 || fDimShapeB [i]. GetVal () == " 1 " )
339
344
continue ;
340
- compute_idx_B +=
341
- " idx_" + std::to_string (i + (fShapeY .size () - fShapeB .size ())) + " * " + stridesB[i] + " +" ;
345
+ compute_idx_B += " idx_" + std::to_string (i + (fDimShapeY .size () - fDimShapeB .size ()));
346
+ if (stridesB[i].GetVal () != " 1" )
347
+ compute_idx_B += " * " + stridesB[i].GetVal ();
348
+ compute_idx_B += " + " ;
342
349
}
343
- compute_idx_B.pop_back ();
350
+ // remove last 3 character " + "
351
+ for (int j = 0 ; j < 3 ; j++)
352
+ compute_idx_B.pop_back ();
344
353
}
345
- for (size_t i = 0 ; i < fShapeY .size (); ++i) {
346
- if (fShapeY [i] != 1 ) {
347
- out << std::string (i + 1 , ' ' ) << " for(size_t idx_" << i << " =0; idx_" << i << " <" << fShapeY [i]
354
+ int nloop = 0 ;
355
+ for (size_t i = 0 ; i < fDimShapeY .size (); ++i) {
356
+ if (fDimShapeY [i].dim != 1 && fDimShapeY [i].GetVal () != " 1" ) {
357
+ nloop++;
358
+ for (int j = 0 ; j < nloop; j++) out << SP;
359
+ out << " for (size_t idx_" << i << " = 0; idx_" << i << " < " << fDimShapeY [i]
348
360
<< " ; ++idx_" << i << " ){\n " ;
349
- compute_idx_Y += " idx_" + std::to_string (i) + " *" + stridesY[i] + " +" ;
361
+ compute_idx_Y += " idx_" + std::to_string (i);
362
+ if (stridesY[i].GetVal () != " 1" )
363
+ compute_idx_Y += " * " + stridesY[i].GetVal ();
364
+ compute_idx_Y += " + " ;
350
365
}
351
366
}
352
- compute_idx_Y.pop_back ();
353
- out << SP << SP << " tensor_" << fNY << " [" << compute_idx_Y << " ] = "
367
+ // remove last 3 characters " + "
368
+ for (int j = 0 ; j < 3 ; j++)
369
+ compute_idx_Y.pop_back ();
370
+ for (int j = 0 ; j < nloop+1 ; j++) out << SP;
371
+ out << " tensor_" << fNY << " [" << compute_idx_Y << " ] = "
354
372
<< BinaryOperatorTrait<T, Op>::Op (" tensor_" + fNA + " [" + compute_idx_A + " ]" ,
355
373
" tensor_" + fNB + " [" + compute_idx_B + " ]" )
356
374
<< " ;\n " ;
357
- for ( size_t i = 0 ; i < fShapeY . size (); ++i) {
358
- if ( fShapeY [i] != 1 ) {
359
- out << std::string ( fShapeY . size () - i + 1 , ' ' ) << " } \n " ;
360
- }
375
+
376
+ for ( int i = nloop; i > 0 ; i-- ) {
377
+ for ( int j = 0 ; j < i; j++) out << SP ;
378
+ out << " } \n " ;
361
379
}
362
380
return out.str ();
363
381
}
0 commit comments