@@ -162,13 +162,14 @@ void TRestVolumeHits::PrintHits() const {
162162 }
163163}
164164
165- void TRestVolumeHits::kMeansClustering (TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt) {
165+ void TRestVolumeHits::kMeansClustering (TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt,
166+ bool fixBoundaries) {
166167 const int nodes = vHits.GetNumberOfHits ();
167168 vector<TRestVolumeHits> volHits (nodes);
168169 // std::cout<<"Nhits "<<hits->GetNumberOfHits()<<" Nodes "<<nodes<<std::endl;
169170 TVector3 nullVector = TVector3 (0 , 0 , 0 );
170171 std::vector<TVector3> centroid (nodes);
171- std::vector<TVector3> centroidOld (nodes, nullVector);
172+ std::vector<TVector3> centroidOld (nodes, nullVector); // used for iterations
172173
173174 for (int h = 0 ; h < nodes; h++) centroid[h] = vHits.GetPosition (h);
174175
@@ -178,6 +179,7 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
178179 double minDist = 1E9 ;
179180 int clIndex = -1 ;
180181 for (int n = 0 ; n < nodes; n++) {
182+ if (fixBoundaries && (n == 0 || n == nodes - 1 )) continue ; // Skip fixed nodes
181183 TVector3 hitPos = hits->GetPosition (i);
182184 double dist = (centroid[n] - hitPos).Mag ();
183185 if (dist < minDist) {
@@ -188,8 +190,11 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
188190 // cout<<minDist<<" "<<clIndex<<endl;
189191 volHits[clIndex].AddHit (*hits, i);
190192 }
193+
194+ // Update centroids and check for convergence
191195 bool converge = true ;
192196 for (int n = 0 ; n < nodes; n++) {
197+ if (fixBoundaries && (n == 0 || n == nodes - 1 )) continue ; // Skip fixed nodes
193198 centroid[n] = volHits[n].GetMeanPosition ();
194199 converge &= (centroid[n] == centroidOld[n]);
195200 centroidOld[n] = centroid[n];
@@ -202,8 +207,12 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
202207 vHits.RemoveHits ();
203208 const TVector3 sigma (0 ., 0 ., 0 .);
204209 for (int n = 0 ; n < nodes; n++) {
205- if (volHits[n].GetNumberOfHits () > 0 )
206- vHits.AddHit (volHits[n].GetMeanPosition (), volHits[n].GetTotalEnergy (), 0 , volHits[n].GetType (0 ),
207- sigma);
210+ if (fixBoundaries && (n == 0 || n == nodes - 1 )) {
211+ vHits.AddHit (centroid[n], 0 , 0 , vHits.GetType (n), sigma);
212+ } else {
213+ if (volHits[n].GetNumberOfHits () > 0 )
214+ vHits.AddHit (volHits[n].GetMeanPosition (), volHits[n].GetTotalEnergy (), 0 ,
215+ volHits[n].GetType (0 ), sigma);
216+ }
208217 }
209218}
0 commit comments