@@ -271,6 +271,29 @@ pub fn compute_stats(input_file: &str, output_prefix: &str, sort_by: Option<Stri
271271
272272 println ! ( "" ) ; // Clear the progress line
273273
274+ // Add contributions from implicit zeros to Pearson residual variance
275+ for ( row_idx, stats) in & row_stats {
276+ let num_zeros = stats. count - stats. nonzero_count ;
277+ if num_zeros > 0 {
278+ let mean = stats. mean ( ) ;
279+ if mean > 0.0 { // Only process if mean > 0 (avoid division by zero)
280+ let denominator = ( mean + mean * mean / theta) . sqrt ( ) ;
281+ let zero_residual = -mean / denominator;
282+
283+ // Clip by sqrt(n_cols) as done for non-zeros
284+ let clip_threshold = ( n_cols as f64 ) . sqrt ( ) ;
285+ let clipped_residual = zero_residual. max ( -clip_threshold) . min ( clip_threshold) ;
286+
287+ // Add contribution of all zeros to the sums
288+ let zero_contribution = num_zeros as f64 * clipped_residual;
289+ let zero_squares_contribution = num_zeros as f64 * clipped_residual * clipped_residual;
290+
291+ * row_pearson_residual_sums. entry ( * row_idx) . or_insert ( 0.0 ) += zero_contribution;
292+ * row_pearson_residual_squares. entry ( * row_idx) . or_insert ( 0.0 ) += zero_squares_contribution;
293+ }
294+ }
295+ }
296+
274297 // Calculate residual variances for each row
275298 for ( row_idx, stats) in & mut row_stats {
276299 // Calculate Pearson residual variance
@@ -413,15 +436,27 @@ impl Stats {
413436 }
414437 }
415438
416- fn finalize ( & mut self , total_count : usize , nonzero_count : usize ) {
417- // Adjust the count and nonzero_count
418- self . count = total_count;
419- self . nonzero_count = nonzero_count;
439+ // fn finalize(&mut self, total_count: usize, nonzero_count: usize) {
440+ // // Adjust the count and nonzero_count
441+ // self.count = total_count;
442+ // self.nonzero_count = nonzero_count;
443+
444+ // // No need to adjust sum or sum_of_squares as zeros don't contribute
420445
421- // No need to adjust sum or sum_of_squares as zeros don't contribute
446+ // // min is nonzero minimum
447+ // }
448+
449+ fn finalize ( & mut self , total_count : usize , nonzero_count : usize ) {
450+ // Adjust the count and nonzero_count
451+ self . count = total_count;
452+ self . nonzero_count = nonzero_count;
422453
423- // min is nonzero minumum
454+ // If there are any zeros (implicit), min should be 0
455+ if nonzero_count < total_count {
456+ self . min = 0.0 ;
424457 }
458+ // Note: max stays as is - the maximum non-zero value is correct
459+ }
425460
426461 fn mean ( & self ) -> f64 {
427462 self . sum / self . count as f64
@@ -436,4 +471,77 @@ impl Stats {
436471 fn std_dev ( & self ) -> f64 {
437472 self . variance ( ) . sqrt ( )
438473 }
474+ }
475+
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::Path;
    use tempfile::{tempdir, NamedTempFile};

    /// `finalize` must account for implicit zeros in the count, the minimum
    /// and (through the count) the mean.
    #[test]
    fn test_stats_handles_implicit_zeros() {
        // Only the explicit values 5.0 and 10.0 are fed in.
        let mut stats = Stats::new(5.0);
        stats.update(10.0);

        // Before finalize the struct knows about the two explicit values
        // only, so the minimum does not yet reflect any implicit zero.
        assert_eq!(stats.count, 2);
        assert_eq!(stats.min, 5.0);

        // Finalize with 10 total elements, i.e. 8 implicit zeros.
        stats.finalize(10, 2);

        assert_eq!(stats.count, 10);
        assert_eq!(stats.nonzero_count, 2);
        assert_eq!(stats.min, 0.0); // zero is now the smallest element
        assert_eq!(stats.max, 10.0); // largest explicit value is unchanged

        // Mean is taken over all elements: (5 + 10) / 10 = 1.5.
        assert!((stats.mean() - 1.5).abs() < 1e-10);
    }

    /// End-to-end smoke test: a tiny MatrixMarket file must yield both the
    /// row and the column statistics files.
    #[test]
    fn test_compute_stats_small_matrix() {
        // 3 x 4 matrix with three explicit entries.
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(
            temp_file,
            "%%MatrixMarket matrix coordinate real general\n\
             3 4 3\n\
             1 1 6.0\n\
             1 2 6.0\n\
             2 3 12.0"
        )
        .unwrap();

        let temp_path = temp_file.path().to_str().unwrap();

        // Write the outputs below a private temporary directory instead of
        // the working directory: the directory is removed when `out_dir` is
        // dropped, even if an assertion below panics, and the test can never
        // clobber or delete an unrelated `test_output_*.tsv`.
        let out_dir = tempdir().unwrap();
        let prefix_path = out_dir.path().join("test_output");
        let output_prefix = prefix_path.to_str().unwrap();

        compute_stats(
            temp_path,
            output_prefix,
            Some("PearsonResidualVar".to_string()),
            Some(100.0),
        )
        .expect("Stats computation failed");

        // Both output files must have been produced.
        assert!(Path::new(&format!("{}_row.tsv", output_prefix)).exists());
        assert!(Path::new(&format!("{}_col.tsv", output_prefix)).exists());
    }

    /// Sanity check of the Pearson residual formula applied to a zero entry.
    ///
    /// NOTE(review): this recomputes the formula locally and does not call
    /// into `compute_stats`, so it guards the math, not the implementation.
    #[test]
    fn test_pearson_residual_calculation() {
        let mean: f64 = 1.2;
        let theta: f64 = 100.0;
        let n_cols: f64 = 10.0;

        let denominator = (mean + mean * mean / theta).sqrt();
        let zero_residual = -mean / denominator;
        let clip_threshold = n_cols.sqrt();

        // A zero observation lies below a positive mean, so its residual is
        // negative.
        assert!(zero_residual < 0.0);

        // For these parameters the residual is not affected by clipping.
        assert!(zero_residual.abs() <= clip_threshold);
    }
}
0 commit comments