Skip to content

Commit 37f164e

Browse files
Add option to fix the first and last node in kMeansClustering (#554)
* add option to fix the first and last node in kMeansClustering * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 43a4c1d commit 37f164e

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

source/framework/core/inc/TRestVolumeHits.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ class TRestVolumeHits : public TRestHits {
6969
return TMath::Sqrt(fSigmaX[n] * fSigmaX[n] + fSigmaY[n] * fSigmaY[n]);
7070
}
7171

72-
static void kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt = 100);
72+
static void kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt = 100,
73+
bool fixBoundaries = false);
7374

7475
// Constructor & Destructor
7576
TRestVolumeHits();

source/framework/core/src/TRestVolumeHits.cxx

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,14 @@ void TRestVolumeHits::PrintHits() const {
162162
}
163163
}
164164

165-
void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt) {
165+
void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& vHits, int maxIt,
166+
bool fixBoundaries) {
166167
const int nodes = vHits.GetNumberOfHits();
167168
vector<TRestVolumeHits> volHits(nodes);
168169
// std::cout<<"Nhits "<<hits->GetNumberOfHits()<<" Nodes "<<nodes<<std::endl;
169170
TVector3 nullVector = TVector3(0, 0, 0);
170171
std::vector<TVector3> centroid(nodes);
171-
std::vector<TVector3> centroidOld(nodes, nullVector);
172+
std::vector<TVector3> centroidOld(nodes, nullVector); // used for iterations
172173

173174
for (int h = 0; h < nodes; h++) centroid[h] = vHits.GetPosition(h);
174175

@@ -178,6 +179,7 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
178179
double minDist = 1E9;
179180
int clIndex = -1;
180181
for (int n = 0; n < nodes; n++) {
182+
if (fixBoundaries && (n == 0 || n == nodes - 1)) continue; // Skip fixed nodes
181183
TVector3 hitPos = hits->GetPosition(i);
182184
double dist = (centroid[n] - hitPos).Mag();
183185
if (dist < minDist) {
@@ -188,8 +190,11 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
188190
// cout<<minDist<<" "<<clIndex<<endl;
189191
volHits[clIndex].AddHit(*hits, i);
190192
}
193+
194+
// Update centroids and check for convergence
191195
bool converge = true;
192196
for (int n = 0; n < nodes; n++) {
197+
if (fixBoundaries && (n == 0 || n == nodes - 1)) continue; // Skip fixed nodes
193198
centroid[n] = volHits[n].GetMeanPosition();
194199
converge &= (centroid[n] == centroidOld[n]);
195200
centroidOld[n] = centroid[n];
@@ -202,8 +207,12 @@ void TRestVolumeHits::kMeansClustering(TRestVolumeHits* hits, TRestVolumeHits& v
202207
vHits.RemoveHits();
203208
const TVector3 sigma(0., 0., 0.);
204209
for (int n = 0; n < nodes; n++) {
205-
if (volHits[n].GetNumberOfHits() > 0)
206-
vHits.AddHit(volHits[n].GetMeanPosition(), volHits[n].GetTotalEnergy(), 0, volHits[n].GetType(0),
207-
sigma);
210+
if (fixBoundaries && (n == 0 || n == nodes - 1)) {
211+
vHits.AddHit(centroid[n], 0, 0, vHits.GetType(n), sigma);
212+
} else {
213+
if (volHits[n].GetNumberOfHits() > 0)
214+
vHits.AddHit(volHits[n].GetMeanPosition(), volHits[n].GetTotalEnergy(), 0,
215+
volHits[n].GetType(0), sigma);
216+
}
208217
}
209218
}

0 commit comments

Comments
 (0)