Skip to content

Commit 1f99c18

Browse files
authored
Merge pull request #1 from stuart-lab/zeroes
Fix bug in Pearson residual variance calculation
2 parents b442db0 + a303400 commit 1f99c18

File tree

4 files changed

+121
-8
lines changed

4 files changed

+121
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/target
22
Cargo.lock
3+
.DS_Store

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# 1.4.0
2+
3+
Fix bug in computation of Pearson residual variance that did not correctly account for zeros
4+
15
# 1.3.0
26

37
Update Pearson residual clipping to be sqrt(N)

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
[package]
22
name = "spars"
3-
version = "1.3.0"
3+
version = "1.4.0"
44
edition = "2021"
5-
authors = ["Tim Stuart <stuartt@gis.a-star.edu.sg>"]
5+
authors = ["Tim Stuart <stuartt@a-star.edu.sg>"]
66
description = "💥 Disk-based sparse matrix statistics and subsetting 💥"
77
license = "MIT"
88
repository = "https://github.com/stuart-lab/spars"

src/stats.rs

Lines changed: 114 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,29 @@ pub fn compute_stats(input_file: &str, output_prefix: &str, sort_by: Option<Stri
271271

272272
println!(""); // Clear the progress line
273273

274+
// Add contributions from implicit zeros to Pearson residual variance
275+
for (row_idx, stats) in &row_stats {
276+
let num_zeros = stats.count - stats.nonzero_count;
277+
if num_zeros > 0 {
278+
let mean = stats.mean();
279+
if mean > 0.0 { // Only process if mean > 0 (avoid division by zero)
280+
let denominator = (mean + mean * mean / theta).sqrt();
281+
let zero_residual = -mean / denominator;
282+
283+
// Clip by sqrt(n_cols) as done for non-zeros
284+
let clip_threshold = (n_cols as f64).sqrt();
285+
let clipped_residual = zero_residual.max(-clip_threshold).min(clip_threshold);
286+
287+
// Add contribution of all zeros to the sums
288+
let zero_contribution = num_zeros as f64 * clipped_residual;
289+
let zero_squares_contribution = num_zeros as f64 * clipped_residual * clipped_residual;
290+
291+
*row_pearson_residual_sums.entry(*row_idx).or_insert(0.0) += zero_contribution;
292+
*row_pearson_residual_squares.entry(*row_idx).or_insert(0.0) += zero_squares_contribution;
293+
}
294+
}
295+
}
296+
274297
// Calculate residual variances for each row
275298
for (row_idx, stats) in &mut row_stats {
276299
// Calculate Pearson residual variance
@@ -413,15 +436,27 @@ impl Stats {
413436
}
414437
}
415438

416-
fn finalize(&mut self, total_count: usize, nonzero_count: usize) {
417-
// Adjust the count and nonzero_count
418-
self.count = total_count;
419-
self.nonzero_count = nonzero_count;
439+
// fn finalize(&mut self, total_count: usize, nonzero_count: usize) {
440+
// // Adjust the count and nonzero_count
441+
// self.count = total_count;
442+
// self.nonzero_count = nonzero_count;
443+
444+
// // No need to adjust sum or sum_of_squares as zeros don't contribute
420445

421-
// No need to adjust sum or sum_of_squares as zeros don't contribute
446+
// // min is nonzero minumum
447+
// }
448+
449+
fn finalize(&mut self, total_count: usize, nonzero_count: usize) {
450+
// Adjust the count and nonzero_count
451+
self.count = total_count;
452+
self.nonzero_count = nonzero_count;
422453

423-
// min is nonzero minumum
454+
// If there are any zeros (implicit), min should be 0
455+
if nonzero_count < total_count {
456+
self.min = 0.0;
424457
}
458+
// Note: max stays as is - the maximum non-zero value is correct
459+
}
425460

426461
fn mean(&self) -> f64 {
427462
self.sum / self.count as f64
@@ -436,4 +471,77 @@ impl Stats {
436471
fn std_dev(&self) -> f64 {
437472
self.variance().sqrt()
438473
}
474+
}
475+
476+
#[cfg(test)]
477+
mod tests {
478+
use super::*;
479+
use std::io::Write;
480+
use tempfile::NamedTempFile;
481+
482+
#[test]
483+
fn test_stats_handles_implicit_zeros() {
484+
// Create a Stats struct with explicit values only
485+
let mut stats = Stats::new(5.0);
486+
stats.update(10.0);
487+
488+
// Before finalize: only tracked 2 values
489+
assert_eq!(stats.count, 2);
490+
assert_eq!(stats.min, 5.0); // Wrong! Missing zeros
491+
492+
// After finalize with 10 total elements (8 implicit zeros)
493+
stats.finalize(10, 2);
494+
495+
assert_eq!(stats.count, 10);
496+
assert_eq!(stats.nonzero_count, 2);
497+
assert_eq!(stats.min, 0.0); // Should now be 0
498+
assert_eq!(stats.max, 10.0); // Max unchanged
499+
500+
// Mean should be (5+10)/10 = 1.5
501+
assert!((stats.mean() - 1.5).abs() < 1e-10);
502+
}
503+
504+
#[test]
505+
fn test_compute_stats_small_matrix() {
506+
// Create a small test matrix
507+
let mut temp_file = NamedTempFile::new().unwrap();
508+
writeln!(temp_file,
509+
"%%MatrixMarket matrix coordinate real general\n\
510+
3 4 3\n\
511+
1 1 6.0\n\
512+
1 2 6.0\n\
513+
2 3 12.0").unwrap();
514+
515+
let temp_path = temp_file.path().to_str().unwrap();
516+
517+
// Run stats computation
518+
compute_stats(temp_path, "test_output", Some("PearsonResidualVar".to_string()), Some(100.0))
519+
.expect("Stats computation failed");
520+
521+
// Check the output files exist
522+
assert!(std::path::Path::new("test_output_row.tsv").exists());
523+
assert!(std::path::Path::new("test_output_col.tsv").exists());
524+
525+
// Clean up
526+
std::fs::remove_file("test_output_row.tsv").ok();
527+
std::fs::remove_file("test_output_col.tsv").ok();
528+
}
529+
530+
#[test]
531+
fn test_pearson_residual_calculation() {
532+
// Test the math for Pearson residuals with zeros
533+
let mean: f64 = 1.2;
534+
let theta: f64 = 100.0;
535+
let n_cols: f64 = 10.0; // Make this f64 too
536+
537+
let denominator = (mean + mean * mean / theta).sqrt();
538+
let zero_residual = -mean / denominator;
539+
let clip_threshold = n_cols.sqrt(); // Now works since n_cols is f64
540+
541+
// Zero residual should be negative
542+
assert!(zero_residual < 0.0);
543+
544+
// Should be within clipping threshold
545+
assert!(zero_residual.abs() <= clip_threshold);
546+
}
439547
}

0 commit comments

Comments
 (0)