3434
// NOTE(review): file-scope using-directive; unqualified names such as `vector` in the
// benchmark functions below rely on it.
3535using namespace std ;
// NOTE(review): not read anywhere in the visible part of this file - presumably a
// switch for extra logging; confirm before removing.
3636bool verbose = false ;
// Test-mode switch read by the benchmark functions below. When true, the inputs are
// filled with fixed constants instead of random values, and BM_SOFIE_Inference dumps
// the first inference result to a file named after the session type (written once per
// benchmark function call).
37+ bool testOutput = true ;
38+
39+
3740template <class S >
3841void BM_SOFIE_Inference (benchmark::State &state)
3942{
4043 size_t inputSize = state.range (0 ); // input size (without batch size)
41- size_t bsize = (state.range (1 ) > 0 ) ? state.range (1 ) : 0 ;
44+ size_t bsize = (state.range (1 ) > 0 ) ? state.range (1 ) : 1 ;
4245 size_t nevts = 64 ;
4346 size_t nrep = nevts / bsize;
4447
4548 vector<float > input (inputSize*nevts);
4649
47- static std::uniform_real_distribution<float > distribution (-1 , 1 );
48- static std::default_random_engine generator;
49- std::generate (input.begin (), input.end (), []() { return distribution (generator); });
50-
50+ if (testOutput) {
51+ input = std::vector<float >(input.size (),1 .);
52+ }
53+ else {
54+ static std::uniform_real_distribution<float > distribution (-1 , 1 );
55+ static std::default_random_engine generator;
56+ std::generate (input.begin (), input.end (), []() { return distribution (generator); });
57+ }
5158 float *input_ptr = input.data ();
5259 S s (" " );
5360
5461 double totDuration = 0 ;
5562 int ntimes = 0 ;
63+ std::vector<float > yOut;
64+ bool first = true ;
65+ bool doWrite = testOutput;
5666 for (auto _ : state) {
5767 auto t1 = std::chrono::high_resolution_clock::now ();
58- for (int i = 0 ; i < nevts; i += bsize)
68+ for (int i = 0 ; i < nevts; i += bsize) {
5969 auto y = s.infer (input.data ()+ inputSize*i);
60-
70+ if (first) {
71+ // std::cout << std::string(typeid(s).name()) << " : " << y[0] << " " << y[1] << std::endl;
72+ yOut = y;
73+ first = false ;
74+ }
75+ }
6176 auto t2 = std::chrono::high_resolution_clock::now ();
6277 auto duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count ();
6378 totDuration += duration / 1 .E3 ; // in milliseconds
6479 ntimes++;
80+ if (doWrite) {
81+ // write output for test
82+ // std::cout << "write output " << std::endl;
83+ std::ofstream f;
84+ std::string filename = std::string (typeid (s).name ()) + " .out" ;
85+ f.open (filename);
86+ f << yOut.size () << std::endl;
87+ for (size_t i = 0 ; i < yOut.size (); i++)
88+ f << yOut[i] << " " ;
89+ f << std::endl;
90+ f.close ();
91+ doWrite = false ;
92+ }
6593 }
6694
6795 state.counters [" time/evt(ms)" ] = totDuration / double (ntimes * nevts);
@@ -95,11 +123,19 @@ void BM_SOFIE_Inference_3(benchmark::State &state)
95123 vector<float > input2 (inputSize2*nevts);
96124 vector<float > input3 (inputSize3*nevts);
97125
126+ if (!testOutput) {
98127 static std::uniform_real_distribution<float > distribution (-1 , 1 );
99128 static std::default_random_engine generator;
100129 std::generate (input1.begin (), input1.end (), []() { return distribution (generator); });
101130 std::generate (input2.begin (), input2.end (), []() { return distribution (generator); });
102131 std::generate (input3.begin (), input3.end (), []() { return distribution (generator); });
132+ }
133+ else {
134+ // generate fixed data
135+ input1 = vector<float >(input1.size (),1 .);
136+ input2 = vector<float >(input2.size (),2 .);
137+ input3 = vector<float >(input3.size (),3 .);
138+ }
103139
104140 S s (" " );
105141
@@ -125,9 +161,9 @@ void BM_SOFIE_Inference_3(benchmark::State &state)
125161}
126162
127163// CMS benchmark (3 inputs)
// NOTE(review): this registration is currently disabled (the active line is commented
// out). The four Args feed BM_SOFIE_Inference_3, whose argument handling is outside the
// visible part of the file - presumably batch size followed by three input sizes; confirm.
128- BENCHMARK_TEMPLATE (BM_SOFIE_Inference_3, TMVA_SOFIE_DDB_B1::Session)->Name(" DDB_B1" )->Args({1 , 1 *27 , 60 *8 , 5 *2 })->Unit(benchmark::kMillisecond );
164+ // BENCHMARK_TEMPLATE(BM_SOFIE_Inference_3, TMVA_SOFIE_DDB_B1::Session)->Name("DDB_B1")->Args({1, 1*27, 60*8, 5*2})->Unit(benchmark::kMillisecond);
129165// Conv Transpose
// Args are {input size per event, batch size}: BM_SOFIE_Inference reads range(0) as the
// input size and range(1) as the batch size.
// NOTE(review): the benchmark name string spells "Cov2DTranspose" rather than
// "Conv2DTranspose" - check whether any result-parsing tooling keys on it before renaming.
130- BENCHMARK_TEMPLATE (BM_SOFIE_Inference, TMVA_SOFIE_Conv2DTranspose_Relu_Sigmoid::Session)->Name(" Cov2DTranspose_B1 " )->Args({1 , 1 * 15 })->Unit(benchmark::kMillisecond );
166+ BENCHMARK_TEMPLATE (BM_SOFIE_Inference, TMVA_SOFIE_Conv2DTranspose_Relu_Sigmoid::Session)->Name(" Cov2DTranspose_Relu_Sigmoid " )->Args({15 , 1 })->Unit(benchmark::kMillisecond );
131167

132168// Gemm benchmarks
// Linear_16: input size 100 per event, batch size 16 (same {input size, batch size} order).
133169BENCHMARK_TEMPLATE (BM_SOFIE_Inference, TMVA_SOFIE_Linear_16::Session)->Name(" Linear_16" )->Args({100 , 16 })->Unit(benchmark::kMillisecond );
0 commit comments