@@ -387,7 +387,6 @@ void RModel::CheckAndFuseOperators() {
387387 std::vector<size_t > fusable_indices;
388388 std::string fusable_propagate_tensor_name;
389389 while (idx < fOperators .size ()) {
390- std::cout<<" \n op currently: " <<toString (fOperators [idx]->GetOpKind ());
391390 if (fOperators [idx]->GetOpKind () != OperatorKind::GEMM && fOperators [idx]->GetOpKind () != OperatorKind::CONV) {
392391 ++idx;
393392 continue ;
@@ -400,13 +399,11 @@ void RModel::CheckAndFuseOperators() {
400399 size_t j = idx + 1 ;
401400 for (; j < fOperators .size ()-1 ; ++j) {
402401 auto opKind = fOperators [j]->GetOpKind ();
403- std::cout<<" \n checking for fusion: " <<toString (opKind);
404402 // Only consider operators with fusable kinds
405403 if (!FusableKinds.count (opKind)) {
406404 // std::cout<<"\n op not fusable: "<<toString(opKind);
407405 break ;
408406 }
409- std::cout<<" \n might be fusable: " <<toString (opKind);
410407
411408 const auto & tensorName = fOperators [j]->GetFusableOutputTensorName ();
412409 auto freqIt = fIntermediateTensorFrequencyLookup .find (tensorName);
@@ -421,21 +418,21 @@ void RModel::CheckAndFuseOperators() {
421418 break ;
422419 }
423420 }
424- // std::cout<<"\nstart fusing: "<<fusable_propagate_tensor_name;
425421 if (!fusable_propagate_tensor_name.empty ()) {
426- // std::cout << "\nOperators to be fused with: " << fusable_propagate_tensor_name;
422+ auto fusable_tensor_type = GetTensorType (fusable_propagate_tensor_name);
423+ auto fusable_tensor_shape = GetDynamicTensorShape (fusable_propagate_tensor_name);
427424 for (auto & index : fusable_indices) {
428- std::cout<<" \n fusing op " <<toString (fOperators [index]->GetOpKind ())<<" , with: " <<fusable_propagate_tensor_name;
429- fOperators [index]->UpdateFusableTensorName (fusable_propagate_tensor_name);
425+ fOperators [index]->UpdateFusableTensorName (fusable_propagate_tensor_name, [this ](const std::string& name) {
426+ this ->RemoveIntermediateTensor (name);
427+ });
430428 }
429+ AddIntermediateTensor (fusable_propagate_tensor_name, fusable_tensor_type, fusable_tensor_shape);
431430 }
432431
433432 idx = std::max (idx + 1 , j);
434433 }
435434}
436435
437-
438-
439436void RModel::Initialize (int batchSize, bool verbose) {
440437 std::map<std::string, size_t > inputParams;
441438 if (batchSize > 0 ) {
0 commit comments