Skip to content

Commit f041575

Browse files
committed
[tmva][sofie] Implement LogSoftmax as a variant of Softmax
Add support for the LogSoftmax operator. Re-use the operator definition of Softmax, just adding a flag to distinguish the log case.
1 parent 1b58a9a commit f041575

File tree

3 files changed

+28
-122
lines changed

3 files changed

+28
-122
lines changed

tmva/sofie/inc/TMVA/ROperator_Softmax.hxx

Lines changed: 21 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ namespace TMVA {
1111
namespace Experimental {
1212
namespace SOFIE {
1313

14-
template <typename T>
14+
// implement Softmax and LogSoftmax
1515
class ROperator_Softmax final : public ROperator {
1616

1717
private:
18+
bool fLogSoftmax; // for the logsoftmax case
1819
int64_t fAttrAxis;
1920

2021
std::string fNX;
@@ -25,8 +26,10 @@ private:
2526

2627
public:
2728
ROperator_Softmax() {}
28-
ROperator_Softmax(int64_t attr_axis, std::string nameX, std::string nameY)
29-
: fAttrAxis(attr_axis), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
29+
ROperator_Softmax(int64_t attr_axis, std::string nameX, std::string nameY, bool logSoftmax)
30+
: fLogSoftmax(logSoftmax),
31+
fAttrAxis(attr_axis), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
32+
3033
{
3134
fInputTensorNames = { fNX };
3235
fOutputTensorNames = { fNY };
@@ -75,7 +78,10 @@ public:
7578
out << SP << SP << "sum += tensor_" << fNY << "[i];\n";
7679
out << SP << "}\n";
7780
out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
78-
out << SP << SP << "tensor_" << fNY << "[i] /= sum;\n";
81+
if (fLogSoftmax)
82+
out << SP << SP << "tensor_" << fNY << "[i] = std::log(tensor_" << fNY << "[i] / sum );\n";
83+
else
84+
out << SP << SP << "tensor_" << fNY << "[i] /= sum;\n";
7985
out << SP << "}\n";
8086
} else {
8187
int k = 0;
@@ -129,125 +135,24 @@ public:
129135
// normalize
130136
for (int j = 0; j < size-1; j++) out << SP;
131137
out << "for (int i = 0; i < " << fShape[axis] << "; i++) {\n";
132-
for (int j = 0; j < size; j++) out << SP;
133-
out << "tensor_" << fNY << "[index + i";
134-
if (stride[axis].GetVal() != "1") out << "*(" << stride[axis] << ")";
135-
out << "] /= sum;\n";
138+
for (int j = 0; j < size; j++) out << SP;
139+
// define the tensor y value for given i index
140+
std::string tensor_y_i = "tensor_" + fNY + "[index + i";
141+
if (stride[axis].GetVal() != "1")
142+
tensor_y_i += "*(" + stride[axis].GetVal() + ")";
143+
tensor_y_i += "]";
144+
if (fLogSoftmax) {
145+
out << tensor_y_i << " = std::log(" << tensor_y_i << " / sum);\n";
146+
} else {
147+
out << tensor_y_i << " /= sum;\n";
148+
}
136149
for (int j = 0; j < size-1; j++) out << SP;
137150
out << "}\n";
138151
//end loops
139152
for (int i = size-2; i >=0; i--) {
140153
for (int j = 0; j < i; j++) out << SP;
141154
out << "}\n";
142155
}
143-
144-
#if 0
145-
size_t batch = fShape[0];
146-
size_t channel = fShape[1];
147-
size_t width = (size > 2) ? fShape[size - 1] : 1;
148-
size_t height = (size > 3) ? fShape[size - 2] : 1;
149-
size_t depth = (size > 4) ? fShape[size - 3] : 1;
150-
size_t hStride = width;
151-
size_t dStride = height * width;
152-
size_t cStride = depth * dStride;
153-
size_t bStride = channel * cStride;
154-
155-
size_t N = 0; // Size of the axis
156-
size_t iStride = 0;
157-
if (axis == 0) {
158-
N = batch;
159-
iStride = bStride;
160-
} else if (axis == 1) {
161-
N = channel;
162-
iStride = cStride;
163-
} else if (axis == size - 1) {
164-
N = width;
165-
iStride = 1;
166-
} else if (size > 3 && axis == size - 2) {
167-
N = height;
168-
iStride = hStride;
169-
} else if (size == 5 && axis == size - 3) {
170-
N = depth;
171-
iStride = dStride;
172-
} else {
173-
throw
174-
std::runtime_error("TMVA::SOFIE - Softmax operator along the axis "
175-
+ std::to_string(fAttrAxis) + " with " + std::to_string(size)
176-
+ "d input tensor not supported.");
177-
}
178-
179-
bool notBatch = axis != 0;
180-
bool notChannel = axis != 1;
181-
bool notDepth = (size == 5 && axis != 2);
182-
bool notHeight = (size == 5 && axis != 3) || (size == 4 && axis != 2);
183-
bool notWidth = (size == 5 && axis != 4) || (size == 4 && axis != 3) || (size == 3 && axis != 2);
184-
185-
if (notBatch) {
186-
out << SP << "for (size_t n = 0; n < " << batch << " ; n++){\n";
187-
}
188-
if (notChannel) {
189-
out << SP << SP << "for (size_t c = 0; c < " << channel << " ; c++){\n";
190-
}
191-
if (notDepth) {
192-
out << SP << SP << "for (size_t d = 0; d < " << depth << " ; d++){\n";
193-
}
194-
if (notHeight) {
195-
out << SP << SP << "for (size_t h = 0; h < " << height << " ; h++){\n";
196-
}
197-
if (notWidth) {
198-
out << SP << SP << "for (size_t w = 0; w < " << width << " ; w++){\n";
199-
}
200-
out << SP << SP << SP << fType << " sum = 0.;\n";
201-
out << SP << SP << SP << "size_t index = 0";
202-
if (notBatch) {
203-
out << " + n * " << bStride;
204-
}
205-
if (notChannel) {
206-
out << "+ c * " << cStride;
207-
}
208-
if (notDepth) {
209-
out << " + d * " << dStride;
210-
}
211-
if (notHeight) {
212-
out << " + h * " << hStride;
213-
}
214-
if (notWidth) {
215-
out << " + w";
216-
}
217-
out << ";\n";
218-
// apply softmax along the axis - find first maximum value for numerical stability
219-
if (N == 0)
220-
throw std::runtime_error("TMVA::SOFIE - Softmax operator is along axis with zero elements");
221-
out << SP << SP << SP << fType << " vmax = tensor_" << fNX << "[index];\n";
222-
out << SP << SP << SP << "for (size_t i = 1; i < " << N << "; i++) {\n";
223-
out << SP << SP << SP << SP << "if (tensor_" << fNX << "[index + i*" << iStride << "] > vmax)\n";
224-
out << SP << SP << SP << SP << SP << "vmax = tensor_" << fNX << "[index + i*" << iStride << "];\n";
225-
out << SP << SP << SP << "}\n";
226-
out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
227-
out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] = std::exp(tensor_" << fNX
228-
<< "[index + i*" << iStride << "] - vmax);\n";
229-
out << SP << SP << SP << SP << "sum += tensor_" << fNY << "[index + i*" << iStride << "];\n";
230-
out << SP << SP << SP << "}\n";
231-
out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
232-
out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] /= sum;\n";
233-
out << SP << SP << SP << "}\n";
234-
if (notWidth) {
235-
out << SP << SP << "}\n"; // end w
236-
}
237-
if (notHeight) {
238-
out << SP << SP << "}\n"; // end h
239-
}
240-
if (notDepth) {
241-
out << SP << SP << "}\n"; // end d
242-
}
243-
if (notChannel) {
244-
out << SP << SP << "}\n"; // end c
245-
}
246-
if (notBatch) {
247-
out << SP << "}\n"; // end n
248-
}
249-
250-
#endif
251156
}
252157
return out.str();
253158
}

tmva/sofie_parsers/src/ParseSoftmax.cxx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,23 @@ ParserFuncSignature ParseSoftmax = [](RModelParser_ONNX &parser, const onnx::Nod
2020
std::unique_ptr<ROperator> op;
2121
std::string output_name = nodeproto.output(0);
2222

23+
bool logSoftmax = (nodeproto.op_type() == "LogSoftmax");
24+
2325
int64_t attr_axis = -1;
2426
if (nodeproto.attribute_size() == 1 && nodeproto.attribute(0).name() == "axis")
2527
attr_axis = nodeproto.attribute(0).i();
2628

27-
switch (input_type) {
28-
case ETensorType::FLOAT: op.reset(new ROperator_Softmax<float>(attr_axis, input_name, output_name)); break;
29-
default:
30-
throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Softmax does not yet support input type " +
31-
std::to_string(static_cast<int>(input_type)));
32-
}
29+
op.reset(new ROperator_Softmax(attr_axis, input_name, output_name, logSoftmax));
30+
3331

3432
if (!parser.IsRegisteredTensorType(output_name)) {
3533
parser.RegisterTensorType(output_name, input_type);
3634
}
3735
return op;
3836
};
3937

38+
39+
4040
} // namespace SOFIE
4141
} // namespace Experimental
4242
} // namespace TMVA

tmva/sofie_parsers/src/RModelParser_ONNX.cxx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
214214
RegisterOperator("Sigmoid", ParseSigmoid);
215215
RegisterOperator("Slice", ParseSlice);
216216
RegisterOperator("Softmax", ParseSoftmax);
217+
RegisterOperator("LogSoftmax", ParseSoftmax);
217218
RegisterOperator("Tanh", ParseTanh);
218219
RegisterOperator("Transpose", ParseTranspose);
219220
RegisterOperator("MatMul", ParseMatMul);

0 commit comments

Comments
 (0)