|
56 | 56 | throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value "); |
57 | 57 |
|
58 | 58 | int concat_dim=0; |
| 59 | + // case of Concat (fNewAxis = 0) and not ConcatFromSequence |
59 | 60 | if(fnewAxis == 0){ |
60 | 61 | for (size_t i = 0; i < inputs.size(); i++) { |
61 | 62 | if (i > 0 && inputs[i].size() != inputs[i - 1].size()) |
|
76 | 77 | ret[0][fAxis] = concat_dim; |
77 | 78 | } |
78 | 79 | std::vector<int> stack; |
| 80 | + // case ConCatFromSequence |
79 | 81 | if(fnewAxis == 1){ |
80 | 82 | for(size_t i = 0; i < inputs.size(); i++) { |
81 | 83 | if (i > 0 && inputs[i].size() != inputs[i-1].size() ) |
|
99 | 101 | } |
100 | 102 |
|
101 | 103 | // get shape of output given inputs. It is going to be called after initialized |
102 | | - std::vector<std::vector<Dim>> ShapeInference(const std::vector<std::vector<Dim>> & inputs) { |
103 | | - std::vector<std::vector<Dim>> ret(1); |
| 104 | + std::vector<Dim> ShapeInference(const std::vector<std::vector<Dim>> & inputs, const RModel & model) { |
| 105 | + std::vector<Dim> ret(inputs[0].size()); |
104 | 106 | // treat negative axis case |
105 | 107 | if (fAxis<0) { |
106 | 108 | fAxis = inputs[0].size()+fAxis; |
107 | 109 | } |
108 | 110 | if (fAxis < 0 || fAxis >= (int) inputs[0].size()) |
109 | 111 | throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value "); |
110 | 112 |
|
111 | | - std::string concat_dim; |
112 | | - size_t i_concat_dim = 0; |
| 113 | + Dim concat_dim; |
113 | 114 | if(fnewAxis == 0){ |
114 | 115 | for (size_t i = 0; i < inputs.size(); i++) { |
115 | 116 | if (i > 0 && inputs[i].size() != inputs[i - 1].size()) |
116 | 117 | throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " + fInputs[i] + " : " + |
117 | 118 | ConvertShapeToString(inputs[i]) + " and " + fInputs[i-1] + " : " + ConvertShapeToString(inputs[i - 1])); |
118 | 119 | for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) { |
119 | 120 | if ((int)iaxis == fAxis) { |
120 | | - // support only non-params shape for the concatenation axis |
121 | | - if (inputs[i][iaxis].isParam) { |
122 | | - if (concat_dim.empty()) |
123 | | - concat_dim = inputs[i][iaxis].GetVal(); |
124 | | - else |
125 | | - concat_dim += std::string("+ ") + inputs[i][iaxis].GetVal(); |
| 121 | + // support both integer and params shape for the concatenation axis |
| 122 | + if (concat_dim.param.empty() && concat_dim.dim == 0) |
| 123 | + concat_dim = inputs[i][iaxis]; |
| 124 | + else if (inputs[i][iaxis].isParam || concat_dim.isParam) { |
| 125 | + concat_dim = |
| 126 | + Dim{ concat_dim.GetVal() + std::string("+ ") + inputs[i][iaxis].GetVal(), |
| 127 | + static_cast<size_t>(-1)}; |
126 | 128 | } else { |
127 | | - i_concat_dim += inputs[i][iaxis].dim; |
128 | | - concat_dim = std::to_string(i_concat_dim); |
| 129 | + concat_dim = Dim { concat_dim.dim + inputs[i][iaxis].dim }; |
129 | 130 | } |
130 | 131 | } |
131 | | - // other dimensions must be the same |
132 | | - else if (i > 0 && inputs[i][iaxis].GetVal() != inputs[i - 1][iaxis].GetVal()) |
| 132 | + else if (i == 0) { |
| 133 | + ret[iaxis] = inputs[i][iaxis]; |
| 134 | + } |
| 135 | + else if ((!inputs[i][iaxis].isParam && !ret[iaxis].isParam) && (inputs[i][iaxis].dim != ret[iaxis].dim)) { |
133 | 136 | throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " + |
134 | 137 | ConvertShapeToString(inputs[i]) + " and " + |
135 | 138 | ConvertShapeToString(inputs[i - 1])); |
| 139 | + } |
| 140 | + else if (!inputs[i][iaxis].isParam && ret[iaxis].isParam){ |
| 141 | + // if shape is not parametric use it |
| 142 | + ret[iaxis] = inputs[i][iaxis]; |
| 143 | + } |
| 144 | + else if (inputs[i][iaxis].isParam && ret[iaxis].isParam) { |
| 145 | + // check which parameter is first in RModel list |
| 146 | + auto & dimNames = model.GetDimShapeNames(); |
| 147 | + auto p1 = std::find(dimNames.begin(), dimNames.end(), inputs[i][iaxis].param); |
| 148 | + auto p2 = std::find(dimNames.begin(), dimNames.end(), ret[iaxis].param); |
| 149 | + if (p1 < p2) ret[iaxis] = inputs[i][iaxis]; |
| 150 | + } |
| 151 | + |
136 | 152 | } |
137 | 153 | } |
138 | 154 |
|
139 | | - // output shape |
140 | | - ret[0] = inputs[0]; |
141 | | - // check if concat_dim is an integer |
142 | | - // case like "2+n" can be converted to an integer so need to check the length |
143 | | - size_t pos = 0; |
144 | | - try { |
145 | | - i_concat_dim = std::stoi(concat_dim, &pos); |
146 | | - if (pos == concat_dim.length()) |
147 | | - ret[0][fAxis] = Dim{i_concat_dim}; // dimension is integer |
148 | | - else |
149 | | - ret[0][fAxis] = Dim{concat_dim}; |
150 | | - } |
151 | | - catch (std::invalid_argument const& ex) { |
152 | | - ret[0][fAxis] = Dim{concat_dim}; |
153 | | - } |
| 155 | + // output shape for concatenated axis |
| 156 | + ret[fAxis] = Dim{concat_dim}; |
| 157 | + // //ret[0] = inputs[0]; |
| 158 | + // // check if concat_dim is an integer |
| 159 | + // // case like "2+n" can be converted to an integer so need to check the length |
| 160 | + // size_t pos = 0; |
| 161 | + // try { |
| 162 | + // i_concat_dim = std::stoi(concat_dim, &pos); |
| 163 | + // if (pos == concat_dim.length()) |
| 164 | + // ret[fAxis] = Dim{i_concat_dim}; // dimension is integer |
| 165 | + // else { |
| 166 | + // // check if a composite expression |
| 167 | + // ret[fAxis] = Dim{concat_dim}; |
| 168 | + // } |
| 169 | + // catch (std::invalid_argument const& ex) { |
| 170 | + |
| 171 | + // } |
154 | 172 |
|
155 | 173 | } |
156 | 174 | // case of stacking (not supported yet) |
|
170 | 188 | } |
171 | 189 | fInputShapes.push_back(model.GetDimTensorShape(it)); |
172 | 190 | } |
173 | | - fOutputShape = ShapeInference(fInputShapes)[0]; |
| 191 | + fOutputShape = ShapeInference(fInputShapes, model); |
174 | 192 | if (model.Verbose()) |
175 | 193 | std::cout << "Output of concat operator has shape " << ConvertDimShapeToString(fOutputShape) << std::endl; |
176 | 194 |
|
|
0 commit comments