@@ -11,10 +11,11 @@ namespace TMVA {
1111namespace Experimental {
1212namespace SOFIE {
1313
14- template < typename T>
14+ // Implements both the Softmax and LogSoftmax operators; fLogSoftmax selects which one is generated
1515class ROperator_Softmax final : public ROperator {
1616
1717private:
18+ bool fLogSoftmax ; // for the logsoftmax case
1819 int64_t fAttrAxis ;
1920
2021 std::string fNX ;
@@ -25,8 +26,10 @@ private:
2526
2627public:
2728 ROperator_Softmax () {}
28- ROperator_Softmax (int64_t attr_axis, std::string nameX, std::string nameY)
29- : fAttrAxis (attr_axis), fNX (UTILITY::Clean_name(nameX)), fNY (UTILITY::Clean_name(nameY))
29+ ROperator_Softmax (int64_t attr_axis, std::string nameX, std::string nameY, bool logSoftmax)
30+ : fLogSoftmax (logSoftmax),
31+ fAttrAxis (attr_axis), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
32+
3033 {
3134 fInputTensorNames = { fNX };
3235 fOutputTensorNames = { fNY };
@@ -75,7 +78,10 @@ public:
7578 out << SP << SP << " sum += tensor_" << fNY << " [i];\n " ;
7679 out << SP << " }\n " ;
7780 out << SP << " for (size_t i = 0; i < " << length << " ; i++){\n " ;
78- out << SP << SP << " tensor_" << fNY << " [i] /= sum;\n " ;
81+ if (fLogSoftmax )
82+ out << SP << SP << " tensor_" << fNY << " [i] = std::log(tensor_" << fNY << " [i] / sum );\n " ;
83+ else
84+ out << SP << SP << " tensor_" << fNY << " [i] /= sum;\n " ;
7985 out << SP << " }\n " ;
8086 } else {
8187 int k = 0 ;
@@ -129,125 +135,24 @@ public:
129135 // normalize
130136 for (int j = 0 ; j < size-1 ; j++) out << SP;
131137 out << " for (int i = 0; i < " << fShape [axis] << " ; i++) {\n " ;
132- for (int j = 0 ; j < size; j++) out << SP;
133- out << " tensor_" << fNY << " [index + i" ;
134- if (stride[axis].GetVal () != " 1" ) out << " *(" << stride[axis] << " )" ;
135- out << " ] /= sum;\n " ;
138+ for (int j = 0 ; j < size; j++) out << SP;
139+ // build the expression string for the output tensor element at index i along the axis
140+ std::string tensor_y_i = " tensor_" + fNY + " [index + i" ;
141+ if (stride[axis].GetVal () != " 1" )
142+ tensor_y_i += " *(" + stride[axis].GetVal () + " )" ;
143+ tensor_y_i += " ]" ;
144+ if (fLogSoftmax ) {
145+ out << tensor_y_i << " = std::log(" << tensor_y_i << " / sum);\n " ;
146+ } else {
147+ out << tensor_y_i << " /= sum;\n " ;
148+ }
136149 for (int j = 0 ; j < size-1 ; j++) out << SP;
137150 out << " }\n " ;
138151 // end loops
139152 for (int i = size-2 ; i >=0 ; i--) {
140153 for (int j = 0 ; j < i; j++) out << SP;
141154 out << " }\n " ;
142155 }
143-
144- #if 0
145- size_t batch = fShape[0];
146- size_t channel = fShape[1];
147- size_t width = (size > 2) ? fShape[size - 1] : 1;
148- size_t height = (size > 3) ? fShape[size - 2] : 1;
149- size_t depth = (size > 4) ? fShape[size - 3] : 1;
150- size_t hStride = width;
151- size_t dStride = height * width;
152- size_t cStride = depth * dStride;
153- size_t bStride = channel * cStride;
154-
155- size_t N = 0; // Size of the axis
156- size_t iStride = 0;
157- if (axis == 0) {
158- N = batch;
159- iStride = bStride;
160- } else if (axis == 1) {
161- N = channel;
162- iStride = cStride;
163- } else if (axis == size - 1) {
164- N = width;
165- iStride = 1;
166- } else if (size > 3 && axis == size - 2) {
167- N = height;
168- iStride = hStride;
169- } else if (size == 5 && axis == size - 3) {
170- N = depth;
171- iStride = dStride;
172- } else {
173- throw
174- std::runtime_error("TMVA::SOFIE - Softmax operator along the axis "
175- + std::to_string(fAttrAxis) + " with " + std::to_string(size)
176- + "d input tensor not supported.");
177- }
178-
179- bool notBatch = axis != 0;
180- bool notChannel = axis != 1;
181- bool notDepth = (size == 5 && axis != 2);
182- bool notHeight = (size == 5 && axis != 3) || (size == 4 && axis != 2);
183- bool notWidth = (size == 5 && axis != 4) || (size == 4 && axis != 3) || (size == 3 && axis != 2);
184-
185- if (notBatch) {
186- out << SP << "for (size_t n = 0; n < " << batch << " ; n++){\n";
187- }
188- if (notChannel) {
189- out << SP << SP << "for (size_t c = 0; c < " << channel << " ; c++){\n";
190- }
191- if (notDepth) {
192- out << SP << SP << "for (size_t d = 0; d < " << depth << " ; d++){\n";
193- }
194- if (notHeight) {
195- out << SP << SP << "for (size_t h = 0; h < " << height << " ; h++){\n";
196- }
197- if (notWidth) {
198- out << SP << SP << "for (size_t w = 0; w < " << width << " ; w++){\n";
199- }
200- out << SP << SP << SP << fType << " sum = 0.;\n";
201- out << SP << SP << SP << "size_t index = 0";
202- if (notBatch) {
203- out << " + n * " << bStride;
204- }
205- if (notChannel) {
206- out << "+ c * " << cStride;
207- }
208- if (notDepth) {
209- out << " + d * " << dStride;
210- }
211- if (notHeight) {
212- out << " + h * " << hStride;
213- }
214- if (notWidth) {
215- out << " + w";
216- }
217- out << ";\n";
218- // apply softmax along the axis - find first maximum value for numerical stability
219- if (N == 0)
220- throw std::runtime_error("TMVA::SOFIE - Softmax operator is along axis with zero elements");
221- out << SP << SP << SP << fType << " vmax = tensor_" << fNX << "[index];\n";
222- out << SP << SP << SP << "for (size_t i = 1; i < " << N << "; i++) {\n";
223- out << SP << SP << SP << SP << "if (tensor_" << fNX << "[index + i*" << iStride << "] > vmax)\n";
224- out << SP << SP << SP << SP << SP << "vmax = tensor_" << fNX << "[index + i*" << iStride << "];\n";
225- out << SP << SP << SP << "}\n";
226- out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
227- out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] = std::exp(tensor_" << fNX
228- << "[index + i*" << iStride << "] - vmax);\n";
229- out << SP << SP << SP << SP << "sum += tensor_" << fNY << "[index + i*" << iStride << "];\n";
230- out << SP << SP << SP << "}\n";
231- out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
232- out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] /= sum;\n";
233- out << SP << SP << SP << "}\n";
234- if (notWidth) {
235- out << SP << SP << "}\n"; // end w
236- }
237- if (notHeight) {
238- out << SP << SP << "}\n"; // end h
239- }
240- if (notDepth) {
241- out << SP << SP << "}\n"; // end d
242- }
243- if (notChannel) {
244- out << SP << SP << "}\n"; // end c
245- }
246- if (notBatch) {
247- out << SP << "}\n"; // end n
248- }
249-
250- #endif
251156 }
252157 return out.str ();
253158 }
0 commit comments