Skip to content

Commit 0dc1ad2

Browse files
committed
clean up comments, add copy to prevent modification by reference
1 parent 8d83b7f commit 0dc1ad2

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

pyest/gm/gm.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pandas as pd
88
from numpy import pi, sqrt
9+
from copy import copy
910
from numpy.random import rand, randn
1011
from scipy.stats._multivariate import _LOG_2PI, _PSD, _squeeze_output
1112
from scipy.stats import Covariance, _covariance
@@ -880,22 +881,19 @@ def cov(self):
880881
"""
881882
# Compute the mean of the distribution
882883
mean = self.mean()
883-
884-
# Extract weights for each sample (assuming self.w exists)
885-
w = self.w
886-
887-
# Normalise weights to sum to 1
884+
885+
w = copy(self.w)
888886
w = w / np.sum(w)
889-
890-
# Extract individual covariances (self.P should be shape (n_samples, nx, nx))
887+
888+
# Extract individual covariances
891889
covs = self.P
892-
893-
# Compute the difference between each sample and the mean
894-
diffs = self.m - mean # shape (n_samples, nx)
895-
890+
891+
# Compute the difference between each mixand mean and the distribution mean
892+
diffs = self.m - mean # shape (nC, nx)
893+
896894
# Compute the outer product of diffs for each sample
897-
outer_diffs = diffs[:, :, None] * diffs[:, None, :] # shape (n_samples, nx, nx)
898-
895+
outer_diffs = diffs[:, :, None] * diffs[:, None, :] # shape (nC, nx, nx)
896+
899897
# Weighted sum of covariances and outer products
900898
return np.sum(w[:, None, None] * (covs + outer_diffs), axis=0)
901899

tests/test_gm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def test_cov():
572572
-3.95911227e+05, -1.80422801e+05, -8.03243134e+04,
573573
5.48813474e-01, -1.02480997e+00, -5.93612009e-01
574574
])
575-
575+
576576
# Single-component GMM (weight 1)
577577
weights = np.array([1])
578578
p0 = gm.GaussianMixture(weights, m0, P0)
@@ -584,7 +584,7 @@ def test_cov():
584584
recurse_depth=3, # maximum recursion depth
585585
min_weight=-np.inf
586586
)
587-
587+
588588
# No weight threshold
589589
split_tol = -np.inf
590590

0 commit comments

Comments
 (0)