Skip to content

Commit d021b20

Browse files
authored
fix(numeric): Fix PgNumeric conversion (#341)
1 parent 96b1210 commit d021b20

File tree

1 file changed

+247
-7
lines changed

1 file changed

+247
-7
lines changed

etl/src/conversions/numeric.rs

Lines changed: 247 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ use tokio_postgres::types::{FromSql, IsNull, ToSql, Type};
88

99
const POSITIVE_SIGN: u16 = 0x0000;
1010
const NEGATIVE_SIGN: u16 = 0x4000;
11-
const NAN_SIGN: u16 = 0xC000;
12-
const POSITIVE_INFINITY_SIGN: u16 = 0xC000;
13-
const NEGATIVE_INFINITY_SIGN: u16 = 0xF000;
11+
const NAN_SIGN: u16 = 0xC000; // NUMERIC_NAN
12+
const POSITIVE_INFINITY_SIGN: u16 = 0xD000; // NUMERIC_PINF
13+
const NEGATIVE_INFINITY_SIGN: u16 = 0xF000; // NUMERIC_NINF
1414

1515
/// Sign indicator for Postgres numeric values.
1616
///
@@ -462,13 +462,28 @@ fn convert_to_base_10000(
462462
base_10000_digits.push(digit);
463463
}
464464

465-
// Strip leading and trailing zeros
465+
// If all groups are zero, normalize to PostgreSQL canonical zero:
466+
// weight = 0, digits = [], positive sign; preserve scale.
467+
if base_10000_digits.iter().all(|&d| d == 0) {
468+
return Ok(PgNumeric::Value {
469+
sign: Sign::Positive,
470+
weight: 0,
471+
scale: dscale,
472+
digits: vec![],
473+
});
474+
}
475+
476+
// Strip leading zeros first and record how many we removed so we can
477+
// adjust the weight correctly. Trailing zeros should NOT influence
478+
// the weight because they are fractional groups after the decimal point.
479+
let leading_zeros_before_strip =
480+
base_10000_digits.iter().take_while(|&&d| d == 0).count() as i32;
481+
466482
strip_leading_zeros(&mut base_10000_digits);
467483
strip_trailing_zeros(&mut base_10000_digits);
468484

469-
// Adjust weight if we stripped leading zeros
470-
let leading_zeros_stripped = ndigits - base_10000_digits.len() as i32;
471-
let final_weight = weight - leading_zeros_stripped;
485+
// Adjust weight only by the number of leading zeros that were removed.
486+
let final_weight = weight - leading_zeros_before_strip;
472487

473488
Ok(PgNumeric::Value {
474489
sign,
@@ -610,6 +625,7 @@ fn format_numeric_value(
610625
#[cfg(test)]
611626
mod tests {
612627
use super::*;
628+
use bytes::BytesMut;
613629

614630
#[test]
615631
fn parse_simple_integer() {
@@ -761,6 +777,76 @@ mod tests {
761777
assert_eq!(format!("{num}"), "0");
762778
}
763779

780+
#[test]
781+
fn zero_canonicalization_basic() {
782+
for s in ["0", "0.0", "000", "000.000"] {
783+
let num = PgNumeric::from_str(s).unwrap();
784+
assert_eq!(num.to_string(), "0");
785+
786+
if let PgNumeric::Value {
787+
sign,
788+
weight,
789+
scale: _,
790+
digits,
791+
} = num
792+
{
793+
assert_eq!(sign, Sign::Positive);
794+
assert_eq!(weight, 0);
795+
assert!(digits.is_empty());
796+
} else {
797+
panic!("Expected Value variant");
798+
}
799+
}
800+
}
801+
802+
#[test]
803+
fn zero_canonicalization_negative_zero() {
804+
for s in ["-0", "-0.00"] {
805+
let num = PgNumeric::from_str(s).unwrap();
806+
assert_eq!(num.to_string(), "0");
807+
808+
if let PgNumeric::Value {
809+
sign,
810+
weight,
811+
scale: _,
812+
digits,
813+
} = num
814+
{
815+
// Normalize to positive zero
816+
assert_eq!(sign, Sign::Positive);
817+
assert_eq!(weight, 0);
818+
assert!(digits.is_empty());
819+
} else {
820+
panic!("Expected Value variant");
821+
}
822+
}
823+
}
824+
825+
#[test]
826+
fn zero_roundtrip_sql() {
827+
for s in ["0", "0.000"] {
828+
let num = PgNumeric::from_str(s).unwrap();
829+
let mut buf = BytesMut::new();
830+
ToSql::to_sql(&num, &Type::NUMERIC, &mut buf).unwrap();
831+
let round = PgNumeric::from_sql(&Type::NUMERIC, &buf).unwrap();
832+
// Internal representation should be canonical-zero with same scale
833+
assert_eq!(num, round);
834+
if let PgNumeric::Value {
835+
sign,
836+
weight,
837+
digits,
838+
..
839+
} = round
840+
{
841+
assert_eq!(sign, Sign::Positive);
842+
assert_eq!(weight, 0);
843+
assert!(digits.is_empty());
844+
} else {
845+
panic!("Expected Value variant");
846+
}
847+
}
848+
}
849+
764850
#[test]
765851
fn display_large_numbers() {
766852
let num = PgNumeric::Value {
@@ -806,4 +892,158 @@ mod tests {
806892
// Should display trailing zeros according to scale
807893
assert!(output.ends_with("0000"));
808894
}
895+
896+
#[test]
897+
fn weight_ignores_trailing_fraction_groups() {
898+
// 0.0012000 → groups: [12, 0], weight must stay at -1 after stripping
899+
let num = PgNumeric::from_str("0.0012000").unwrap();
900+
assert_eq!(num.to_string(), "0.0012000");
901+
902+
if let PgNumeric::Value {
903+
sign,
904+
weight,
905+
scale,
906+
ref digits,
907+
} = num
908+
{
909+
assert_eq!(sign, Sign::Positive);
910+
assert_eq!(weight, -1, "weight should remain -1");
911+
assert_eq!(scale, 7, "scale preserved for display");
912+
assert_eq!(
913+
digits.as_slice(),
914+
&[12],
915+
"trailing base-10000 zero group stripped"
916+
);
917+
} else {
918+
panic!("Expected Value variant");
919+
}
920+
}
921+
922+
#[test]
923+
fn weight_and_groups_boundary_cases() {
924+
// 9999.9999 is exactly two full groups: [9999, 9999], weight 0
925+
let num_1 = PgNumeric::from_str("9999.9999").unwrap();
926+
assert_eq!(num_1.to_string(), "9999.9999");
927+
928+
if let PgNumeric::Value {
929+
sign,
930+
weight,
931+
scale,
932+
ref digits,
933+
} = num_1
934+
{
935+
assert_eq!(sign, Sign::Positive);
936+
assert_eq!(weight, 0);
937+
assert_eq!(scale, 4);
938+
assert_eq!(digits.as_slice(), &[9999, 9999]);
939+
} else {
940+
panic!("Expected Value variant");
941+
}
942+
943+
// 10000.0001 crosses the 10^4 boundary
944+
let num_2 = PgNumeric::from_str("10000.0001").unwrap();
945+
assert_eq!(num_2.to_string(), "10000.0001");
946+
947+
if let PgNumeric::Value {
948+
sign,
949+
weight,
950+
scale,
951+
ref digits,
952+
} = num_2
953+
{
954+
assert_eq!(sign, Sign::Positive);
955+
assert_eq!(weight, 1);
956+
assert_eq!(scale, 4);
957+
// Two integer groups [1, 0] and one fractional group [1]
958+
assert_eq!(digits.as_slice(), &[1, 0, 1]);
959+
} else {
960+
panic!("Expected Value variant");
961+
}
962+
}
963+
964+
#[test]
965+
fn ignores_input_leading_zeros() {
966+
let num = PgNumeric::from_str("0000120.00").unwrap();
967+
assert_eq!(num.to_string(), "120.00");
968+
969+
if let PgNumeric::Value {
970+
sign,
971+
weight,
972+
scale,
973+
ref digits,
974+
} = num
975+
{
976+
assert_eq!(sign, Sign::Positive);
977+
assert_eq!(weight, 0);
978+
assert_eq!(scale, 2);
979+
assert_eq!(digits.as_slice(), &[120]);
980+
} else {
981+
panic!("Expected Value variant");
982+
}
983+
}
984+
985+
#[test]
986+
fn roundtrip_stability() {
987+
let cases = [
988+
"120.00",
989+
"1.2000",
990+
"0.0120",
991+
"9999.9999",
992+
"10000.0001",
993+
"-120.00",
994+
"1200000",
995+
];
996+
997+
for case in &cases {
998+
let parsed = PgNumeric::from_str(case).unwrap();
999+
let printed = parsed.to_string();
1000+
let reparsed = PgNumeric::from_str(&printed).unwrap();
1001+
1002+
// String should be stable across two parses
1003+
assert_eq!(printed, reparsed.to_string(), "unstable print for {case}");
1004+
1005+
// Value representation should be equal across parse/print/parse
1006+
assert_eq!(parsed, reparsed, "unstable internal value for {case}");
1007+
}
1008+
}
1009+
1010+
#[test]
1011+
fn large_integer_weight() {
1012+
// 1,200,000 = 120*10000 + 0 → digits [120, 0] before strip trailing zero
1013+
// We expect trailing zero group to be stripped, weight stays 1.
1014+
let num = PgNumeric::from_str("1200000").unwrap();
1015+
assert_eq!(num.to_string(), "1200000");
1016+
1017+
if let PgNumeric::Value {
1018+
sign,
1019+
weight,
1020+
scale,
1021+
ref digits,
1022+
} = num
1023+
{
1024+
assert_eq!(sign, Sign::Positive);
1025+
assert_eq!(weight, 1);
1026+
assert_eq!(scale, 0);
1027+
assert_eq!(digits.as_slice(), &[120]);
1028+
} else {
1029+
panic!("Expected Value variant");
1030+
}
1031+
}
1032+
1033+
#[test]
1034+
fn tosql_fromsql_special_values() {
1035+
let cases = [
1036+
PgNumeric::NaN,
1037+
PgNumeric::PositiveInfinity,
1038+
PgNumeric::NegativeInfinity,
1039+
];
1040+
1041+
for case in cases {
1042+
let mut buf = BytesMut::new();
1043+
ToSql::to_sql(&case, &Type::NUMERIC, &mut buf).unwrap();
1044+
let round = PgNumeric::from_sql(&Type::NUMERIC, &buf).unwrap();
1045+
assert_eq!(format!("{}", case), format!("{}", round));
1046+
assert_eq!(case, round);
1047+
}
1048+
}
8091049
}

0 commit comments

Comments
 (0)