Skip to content

Commit 291154b

Browse files
committed
fix: avoid reallocating memory for tensors of fused operators
1 parent dd1159f commit 291154b

File tree

8 files changed: +43 additions, −53 deletions

tmva/sofie/inc/TMVA/RModel.hxx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ public:
147147

148148
void SetOptimizationLevel(const OptimizationLevel &optim_level) { fOptimizationLevel = optim_level; }
149149

150+
void RemoveIntermediateTensor(const std::string& tensor_name){
151+
fIntermediateTensorInfos.erase(tensor_name);
152+
}
153+
150154
protected:
151155
// internal functions
152156
// generate code for the initialized tensors

tmva/sofie/inc/TMVA/ROperator.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public:
6060
virtual std::string GenerateSessionMembersCode(std::string /*opName*/) { return ""; }
6161
virtual std::string Header() { return "";}
6262
virtual std::string GetFusableOutputTensorName() { return "";}
63-
virtual void UpdateFusableTensorName(std::string){ return;};
63+
virtual void UpdateFusableTensorName(std::string, const std::function<void(const std::string&)>& removal_func){ return;};
6464

6565

6666
//virtual void Forward_reference() = 0;

tmva/sofie/inc/TMVA/ROperator_BatchNormalization.hxx

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -235,18 +235,15 @@ public:
235235
return fNY;
236236
}
237237

238-
void UpdateFusableTensorName(std::string fusable_tensor_name){
239-
fNX = UTILITY::Clean_name(fusable_tensor_name);
240-
fNY = UTILITY::Clean_name(fusable_tensor_name);
241-
fInputTensorNames = { fNX, fNScale };
242-
if (!fNB.empty()){
243-
fInputTensorNames.emplace_back(fNB);
244-
}
245-
246-
fOutputTensorNames = { fNY };
247-
std::cout<<"\ncalled from gemm";
238+
void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func){
239+
removal_func(fNX);
240+
removal_func(fNY);
241+
fNX = fusable_tensor_name;
242+
fNY = fusable_tensor_name;
243+
fInputTensorNames[0] = fNX;
244+
fOutputTensorNames[0] = fNY;
245+
}
248246

249-
}
250247
};
251248

252249
}//SOFIE

tmva/sofie/inc/TMVA/ROperator_Conv.hxx

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -530,15 +530,11 @@ public:
530530
std::string GetFusableOutputTensorName() override {
531531
return fNY;
532532
}
533-
void UpdateFusableTensorName(std::string fusable_tensor_name) override {
534-
std::cout<<"\ncalled from conv";
533+
void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func) override {
534+
removal_func(fNY);
535535
fNY = fusable_tensor_name;
536-
fOutputTensorNames = { fNY };
537-
convK = fNX +"_f";
538-
imcol = fNX +"_xcol";
539-
fOutputTensorNames.emplace_back(convK);
540-
fOutputTensorNames.emplace_back(imcol);
541-
}
536+
fOutputTensorNames[0] = fNY;
537+
}
542538
};
543539

544540
} // namespace SOFIE

tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,12 @@ namespace SOFIE{
390390
return fNY;
391391
}
392392

393-
void UpdateFusableTensorName(std::string fusable_tensor_name){
394-
fNY = UTILITY::Clean_name(fusable_tensor_name);
395-
fOutputTensorNames = { fNY };
396-
std::cout<<"\ncalled from gemm";
393+
void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func){
394+
removal_func(fNY);
395+
fNY = fusable_tensor_name;
396+
fOutputTensorNames[0] = fNY;
397397
}
398+
398399
};
399400

400401

tmva/sofie/inc/TMVA/ROperator_LayerNormalization.hxx

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -342,22 +342,15 @@ public:
342342
return fNY;
343343
}
344344

345-
void UpdateFusableTensorName(std::string fusable_tensor_name){
346-
fNX = UTILITY::Clean_name(fusable_tensor_name);
347-
fNY = UTILITY::Clean_name(fusable_tensor_name);
348-
fInputTensorNames = { fNX, fNScale };
349-
if (!fNB.empty()){
350-
fInputTensorNames.emplace_back(fNB);
351-
}
352-
353-
fOutputTensorNames = { fNY };
354-
if (!fNMean.empty()){
355-
fOutputTensorNames.emplace_back(fNMean);
356-
}
357-
if (!fNInvStdDev.empty()){
358-
fOutputTensorNames.emplace_back(fNInvStdDev);
359-
}
345+
void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func){
346+
removal_func(fNX);
347+
removal_func(fNY);
348+
fNX = fusable_tensor_name;
349+
fNY = fusable_tensor_name;
350+
fInputTensorNames[0] = fNX;
351+
fOutputTensorNames[0] = fNY;
360352
}
353+
361354
};
362355

363356
} // namespace SOFIE

tmva/sofie/inc/TMVA/ROperator_Relu.hxx

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,13 @@ public:
7272
return fNY;
7373
}
7474

75-
void UpdateFusableTensorName(std::string fusable_tensor_name){
76-
fNX = fusable_tensor_name;
77-
fNY = fusable_tensor_name;
78-
fInputTensorNames = { fNX };
79-
fOutputTensorNames = { fNY };
75+
void UpdateFusableTensorName(std::string fusable_tensor_name, const std::function<void(const std::string&)>& removal_func){
76+
removal_func(fNX);
77+
removal_func(fNY);
78+
fNX = fusable_tensor_name;
79+
fNY = fusable_tensor_name;
80+
fInputTensorNames[0] = fNX;
81+
fOutputTensorNames[0] = fNY;
8082
}
8183

8284
};

tmva/sofie/src/RModel.cxx

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ void RModel::CheckAndFuseOperators() {
387387
std::vector<size_t> fusable_indices;
388388
std::string fusable_propagate_tensor_name;
389389
while (idx < fOperators.size()) {
390-
std::cout<<"\nop currently: "<<toString(fOperators[idx]->GetOpKind());
391390
if (fOperators[idx]->GetOpKind() != OperatorKind::GEMM && fOperators[idx]->GetOpKind() != OperatorKind::CONV) {
392391
++idx;
393392
continue;
@@ -400,13 +399,11 @@ void RModel::CheckAndFuseOperators() {
400399
size_t j = idx + 1;
401400
for (; j < fOperators.size()-1; ++j) {
402401
auto opKind = fOperators[j]->GetOpKind();
403-
std::cout<<"\nchecking for fusion: "<<toString(opKind);
404402
// Only consider operators with fusable kinds
405403
if (!FusableKinds.count(opKind)) {
406404
// std::cout<<"\n op not fusable: "<<toString(opKind);
407405
break;
408406
}
409-
std::cout<<"\nmight be fusable: "<<toString(opKind);
410407

411408
const auto& tensorName = fOperators[j]->GetFusableOutputTensorName();
412409
auto freqIt = fIntermediateTensorFrequencyLookup.find(tensorName);
@@ -421,21 +418,21 @@ void RModel::CheckAndFuseOperators() {
421418
break;
422419
}
423420
}
424-
// std::cout<<"\nstart fusing: "<<fusable_propagate_tensor_name;
425421
if (!fusable_propagate_tensor_name.empty()) {
426-
// std::cout << "\nOperators to be fused with: " << fusable_propagate_tensor_name;
422+
auto fusable_tensor_type = GetTensorType(fusable_propagate_tensor_name);
423+
auto fusable_tensor_shape = GetDynamicTensorShape(fusable_propagate_tensor_name);
427424
for (auto& index : fusable_indices) {
428-
std::cout<<"\nfusing op "<<toString(fOperators[index]->GetOpKind())<<" , with: "<<fusable_propagate_tensor_name;
429-
fOperators[index]->UpdateFusableTensorName(fusable_propagate_tensor_name);
425+
fOperators[index]->UpdateFusableTensorName(fusable_propagate_tensor_name, [this](const std::string& name) {
426+
this->RemoveIntermediateTensor(name);
427+
});
430428
}
429+
AddIntermediateTensor(fusable_propagate_tensor_name, fusable_tensor_type, fusable_tensor_shape);
431430
}
432431

433432
idx = std::max(idx + 1, j);
434433
}
435434
}
436435

437-
438-
439436
void RModel::Initialize(int batchSize, bool verbose) {
440437
std::map<std::string, size_t> inputParams;
441438
if (batchSize > 0) {

0 commit comments

Comments (0)