-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathRModel.hxx
More file actions
122 lines (92 loc) · 3.63 KB
/
RModel.hxx
File metadata and controls
122 lines (92 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#ifndef TMVA_SOFIE_RMODEL
#define TMVA_SOFIE_RMODEL
#include <algorithm>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "SOFIE_common.hxx"
#include "ROperator.hxx"
namespace TMVA{
namespace Experimental{
namespace SOFIE{
/// In-memory representation of a parsed model: its tensors, its operators and
/// the code generated from them (see Generate() / PrintGenerated()).
/// The type is move-only: copies are deleted, moves are defined out of line.
class RModel{
private:
   // Graph inputs only; operator inputs (intermediate tensors) are not included here.
   std::unordered_map<std::string, InputTensorInfo> fInputTensorInfos;
   // NOTE(review): presumably graph inputs whose shape is fully known (TensorInfo
   // rather than InputTensorInfo) -- confirm against RModel.cxx.
   std::unordered_map<std::string, TensorInfo> fReadyInputTensorInfos;
   // Tensors that carry data; presumably filled by AddInitializedTensor() /
   // UpdateInitializedTensor() -- confirm in RModel.cxx.
   std::unordered_map<std::string, InitializedTensor> fInitializedTensors;
   // Tensors flowing between operators; presumably filled by AddIntermediateTensor().
   std::unordered_map<std::string, TensorInfo> fIntermediateTensorInfos;
   // Names of the graph outputs, replaced wholesale by AddOutputTensorNameList().
   std::vector<std::string> fOutputTensorNames;
   // Operators owned by the model.
   std::vector<std::unique_ptr<ROperator>> fOperators;
   std::string fName="UnnamedModel";
   std::string fFileName;   // file name of original model file for identification
   std::string fParseTime;  // UTC date and time string at parsing
   std::string fGC;         // generated code (written to std::cout by PrintGenerated())
   // NOTE(review): presumably marks whether generated code needs a GEMM routine -- confirm.
   bool fNeedGemm = true;
   // Whitelist consulted by AddNeededStdLib(); other names are ignored.
   const std::vector<std::string> fAllowedStdLib = {"algorithm"};
   // Standard-library header names recorded for the generated code; starts with "vector".
   std::set<std::string> fNeededStdLib = {"vector"};
public:
   // Explicit move constructor / move assignment (defined out of line).
   // NOTE(review): consider marking both noexcept together with their definitions.
   RModel(RModel&& other);
   RModel& operator=(RModel&& other);
   // Copying is disallowed: the model uniquely owns its operators.
   RModel(const RModel& other) = delete;
   RModel& operator=(const RModel& other) = delete;
   RModel() = default;
   RModel(std::string name, std::string parsedtime);
   const std::vector<size_t>& GetTensorShape(std::string name);
   const ETensorType& GetTensorType(std::string name);
   bool CheckIfTensorAlreadyExist(std::string tensor_name);
   void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<Dim> shape);
   void AddInputTensorInfo(std::string input_name, ETensorType type, std::vector<size_t> shape);
   // NOTE(review): order_execution == -1 presumably means "append" -- confirm in RModel.cxx.
   void AddOperator(std::unique_ptr<ROperator> op, int order_execution = -1);
   void AddInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape, std::shared_ptr<void> data);
   void AddIntermediateTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape);

   /// Record `libname` in fNeededStdLib if it appears in fAllowedStdLib;
   /// names outside the whitelist are silently ignored.
   /// @param libname standard-library header name, e.g. "algorithm".
   void AddNeededStdLib(std::string libname){
      if (std::find(fAllowedStdLib.begin(), fAllowedStdLib.end(), libname) != fAllowedStdLib.end()) {
         fNeededStdLib.insert(std::move(libname));
      }
   }

   /// Replace the list of graph output tensor names.
   /// @param outputtensornames new output names; taken by value and moved into place.
   void AddOutputTensorNameList(std::vector<std::string> outputtensornames){
      fOutputTensorNames = std::move(outputtensornames);
   }

   void UpdateInitializedTensor(std::string tensor_name, ETensorType type, std::vector<std::size_t> shape, std::shared_ptr<void> data);
   std::shared_ptr<void> GetInitializedTensorData(std::string tensor_name);
   void Initialize();
   void Generate();

   /// Write the generated code (fGC) to std::cout, without a trailing newline.
   void PrintGenerated(){
      std::cout << fGC;
   }

   void PrintIntermediateTensors();
   void OutputGenerated(std::string filename = "");
   /*
   NOTE(review): disabled overload. It would not compile as written:
   static_cast<void>(...) yields void, not a pointer to the data.
   template <typename T>
   void AddInitializedTensor(std::string tensor_name, RTensor<T> new_tensor){
      //a view only
      T obj;
      if (fInitializedTensors.find(tensor_name) != fInitializedTensors.end()){
         throw std::runtime_error("TMVA-SOFIE: initialized tensor with name " + tensor_name + " already exists \n");
      }
      InitializedTensor new_tensor_ {GetTemplatedType(obj), new_tensor.GetShape() , static_cast<void>(new_tensor.GetData())};
      fInitializedTensors[tensor_name] = new_tensor_;
   }
   */
   void PrintRequiredInputTensors();
   void PrintInitializedTensors();
   void HeadInitializedTensors(std::string name, int n_print = 50);

   // Members release their own storage. The former manual free() of initialized
   // tensor data is not needed. NOTE(review): this assumes InitializedTensor holds
   // its data through the shared_ptr<void> passed to AddInitializedTensor().
   ~RModel() = default;
};
}//SOFIE
}//Experimental
}//TMVA
#endif //TMVA_SOFIE_RMODEL