1919
2020struct Stats {
2121 std::vector<float > values;
22+ std::vector<int > counts;
2223 int ncall = 0 ;
2324};
2425
@@ -121,12 +122,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
121122 auto & e = m_stats[wname];
122123
123124 ++e.ncall ;
124- // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
125- // using the following line, we can correct for that if needed by replacing the line above with:
126- // if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
127125
128126 if (e.values .empty ()) {
129127 e.values .resize (src1->ne [0 ]*n_as, 0 );
128+ e.counts .resize (src1->ne [0 ]*n_as, 0 );
130129 }
131130 else if (e.values .size () != (size_t )src1->ne [0 ]*n_as) {
132131 fprintf (stderr, " Oops: inconsistent size for %s (%d vs %d)\n " , wname.c_str (), (int )e.values .size (), (int )src1->ne [0 ]*n_as);
@@ -153,6 +152,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
153152
154153 for (int j = 0 ; j < (int )src1->ne [0 ]; ++j) {
155154 e.values [e_start + j] += x[j]*x[j];
155+ e.counts [e_start + j]++;
156156 }
157157 }
158158 }
@@ -170,6 +170,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
170170 auto & e = m_stats[wname];
171171 if (e.values .empty ()) {
172172 e.values .resize (src1->ne [0 ], 0 );
173+ e.counts .resize (src1->ne [0 ], 0 );
173174 }
174175 else if (e.values .size () != (size_t )src1->ne [0 ]) {
175176 fprintf (stderr, " Oops: inconsistent size for %s (%d vs %d)\n " , wname.c_str (), (int )e.values .size (), (int )src1->ne [0 ]);
@@ -183,6 +184,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
183184 const float * x = data + row * src1->ne [0 ];
184185 for (int j = 0 ; j < (int )src1->ne [0 ]; ++j) {
185186 e.values [j] += x[j]*x[j];
187+ e.counts [j]++;
186188 }
187189 }
188190 if (e.ncall > m_last_call) {
@@ -222,7 +224,13 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co
222224 out.write ((const char *) &p.second .ncall , sizeof (p.second .ncall ));
223225 int nval = p.second .values .size ();
224226 out.write ((const char *) &nval, sizeof (nval));
225- if (nval > 0 ) out.write ((const char *) p.second .values .data (), nval * sizeof (float ));
227+ if (nval > 0 ) {
228+ std::vector<float > tmp (nval);
229+ for (int i = 0 ; i < nval; i++) {
230+ tmp[i] = (p.second .values [i] / static_cast <float >(p.second .counts [i])) * static_cast <float >(p.second .ncall );
231+ }
232+ out.write ((const char *)tmp.data (), nval*sizeof (float ));
233+ }
226234 }
227235
228236 // Write the number of call the matrix was computed with
@@ -270,14 +278,28 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma
270278 imatrix_data = {};
271279 return false ;
272280 }
273- e.values .resize (nval);
274- in.read ((char *)e.values .data (), nval*sizeof (float ));
281+
282+ // When re-called from load_imatrix() with add set, this will already be created.
283+ if (e.values .empty ()) {
284+ e.values .resize (nval, 0 );
285+ e.counts .resize (nval, 0 );
286+ }
287+
288+ std::vector<float > tmp (nval);
289+ in.read ((char *)tmp.data (), nval*sizeof (float ));
275290 if (in.fail ()) {
276291 printf (" %s: failed reading data for entry %d\n " ,__func__,i);
277292 imatrix_data = {};
278293 return false ;
279294 }
280- e.ncall = ncall;
295+
296+ // Recreate the state as expected by save_imatrix(), and corerct for weighted sum.
297+ for (int i = 0 ; i < nval; i++) {
298+ e.values [i] += tmp[i];
299+ e.counts [i] += ncall;
300+ }
301+ e.ncall += ncall;
302+
281303 }
282304 return true ;
283305}
0 commit comments