Skip to content

Commit 989fc1d

Browse files
Correctly parse decimal/numeric types
1 parent 9be64a7 commit 989fc1d

File tree

2 files changed

+291
-44
lines changed

2 files changed

+291
-44
lines changed

etl-destinations/src/delta/schema.rs

Lines changed: 107 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,30 @@ use etl::types::{
1919
};
2020
use std::sync::Arc;
2121

22+
/// Extract numeric precision from Postgres atttypmod
23+
/// Based on: https://stackoverflow.com/questions/72725508/how-to-calculate-numeric-precision-and-other-vals-from-atttypmod
24+
fn extract_numeric_precision(atttypmod: i32) -> u8 {
25+
if atttypmod == -1 {
26+
// No limit specified, use maximum precision
27+
38
28+
} else {
29+
let precision = ((atttypmod - 4) >> 16) & 65535;
30+
std::cmp::min(precision as u8, 38) // Cap at Arrow's max precision
31+
}
32+
}
33+
34+
/// Extract numeric scale from Postgres atttypmod
35+
/// Based on: https://stackoverflow.com/questions/72725508/how-to-calculate-numeric-precision-and-other-vals-from-atttypmod
36+
fn extract_numeric_scale(atttypmod: i32) -> i8 {
37+
if atttypmod == -1 {
38+
// No limit specified, use reasonable default scale
39+
18
40+
} else {
41+
let scale = (atttypmod - 4) & 65535;
42+
std::cmp::min(scale as i8, 38) // Cap at reasonable scale
43+
}
44+
}
45+
2246
/// Converts TableRows to Arrow RecordBatch for Delta Lake writes
2347
pub struct TableRowEncoder;
2448

@@ -57,7 +81,8 @@ impl TableRowEncoder {
5781
.column_schemas
5882
.iter()
5983
.map(|col_schema| {
60-
let data_type = Self::postgres_type_to_arrow_type(&col_schema.typ);
84+
let data_type =
85+
Self::postgres_type_to_arrow_type(&col_schema.typ, col_schema.modifier);
6186
ArrowField::new(&col_schema.name, data_type, col_schema.nullable)
6287
})
6388
.collect();
@@ -66,7 +91,7 @@ impl TableRowEncoder {
6691
}
6792

6893
/// Map Postgres types to appropriate Arrow types
69-
pub(crate) fn postgres_type_to_arrow_type(pg_type: &PGType) -> ArrowDataType {
94+
pub(crate) fn postgres_type_to_arrow_type(pg_type: &PGType, modifier: i32) -> ArrowDataType {
7095
match *pg_type {
7196
// Boolean types
7297
PGType::BOOL => ArrowDataType::Boolean,
@@ -131,15 +156,20 @@ impl TableRowEncoder {
131156
ArrowDataType::Float64,
132157
true,
133158
))),
134-
135-
// Decimal types - use high precision for NUMERIC
136-
PGType::NUMERIC => ArrowDataType::Decimal128(38, 18), // Max precision, reasonable scale
137-
PGType::NUMERIC_ARRAY => ArrowDataType::List(Arc::new(ArrowField::new(
138-
"item",
139-
ArrowDataType::Decimal128(38, 18),
140-
true,
141-
))),
142-
159+
PGType::NUMERIC => {
160+
let precision = extract_numeric_precision(modifier);
161+
let scale = extract_numeric_scale(modifier);
162+
ArrowDataType::Decimal128(precision, scale)
163+
}
164+
PGType::NUMERIC_ARRAY => {
165+
let precision = extract_numeric_precision(modifier);
166+
let scale = extract_numeric_scale(modifier);
167+
ArrowDataType::List(Arc::new(ArrowField::new(
168+
"item",
169+
ArrowDataType::Decimal128(precision, scale),
170+
true,
171+
)))
172+
}
143173
// Date/Time types
144174
PGType::DATE => ArrowDataType::Date32,
145175
PGType::DATE_ARRAY => ArrowDataType::List(Arc::new(ArrowField::new(
@@ -623,39 +653,38 @@ impl TableRowEncoder {
623653
/// Convert cells to decimal128 array
624654
fn convert_to_decimal128_array(
625655
cells: Vec<&PGCell>,
626-
_precision: u8,
627-
_scale: i8,
656+
precision: u8,
657+
scale: i8,
628658
) -> Result<ArrayRef, ArrowError> {
629659
let values: Vec<Option<i128>> = cells
630660
.iter()
631661
.map(|cell| match cell {
632662
PGCell::Null => None,
633663
PGCell::Numeric(n) => {
634-
// Convert PgNumeric to decimal128
635664
// This is a simplified conversion - ideally we'd preserve the exact decimal representation
636665
if let Ok(string_val) = n.to_string().parse::<f64>() {
637666
// Scale up by the scale factor and convert to i128
638-
let scaled = (string_val * 10_f64.powi(_scale as i32)) as i128;
667+
let scaled = (string_val * 10_f64.powi(scale as i32)) as i128;
639668
Some(scaled)
640669
} else {
641670
None
642671
}
643672
}
644-
PGCell::I16(i) => Some(*i as i128 * 10_i128.pow(_scale as u32)),
645-
PGCell::I32(i) => Some(*i as i128 * 10_i128.pow(_scale as u32)),
646-
PGCell::I64(i) => Some(*i as i128 * 10_i128.pow(_scale as u32)),
647-
PGCell::U32(i) => Some(*i as i128 * 10_i128.pow(_scale as u32)),
673+
PGCell::I16(i) => Some(*i as i128 * 10_i128.pow(scale as u32)),
674+
PGCell::I32(i) => Some(*i as i128 * 10_i128.pow(scale as u32)),
675+
PGCell::I64(i) => Some(*i as i128 * 10_i128.pow(scale as u32)),
676+
PGCell::U32(i) => Some(*i as i128 * 10_i128.pow(scale as u32)),
648677
PGCell::F32(f) => {
649-
let scaled = (*f as f64 * 10_f64.powi(_scale as i32)) as i128;
678+
let scaled = (*f as f64 * 10_f64.powi(scale as i32)) as i128;
650679
Some(scaled)
651680
}
652681
PGCell::F64(f) => {
653-
let scaled = (f * 10_f64.powi(_scale as i32)) as i128;
682+
let scaled = (f * 10_f64.powi(scale as i32)) as i128;
654683
Some(scaled)
655684
}
656685
PGCell::String(s) => {
657686
if let Ok(val) = s.parse::<f64>() {
658-
let scaled = (val * 10_f64.powi(_scale as i32)) as i128;
687+
let scaled = (val * 10_f64.powi(scale as i32)) as i128;
659688
Some(scaled)
660689
} else {
661690
None
@@ -664,7 +693,11 @@ impl TableRowEncoder {
664693
_ => None,
665694
})
666695
.collect();
667-
Ok(Arc::new(Decimal128Array::from(values)))
696+
697+
let decimal_type = ArrowDataType::Decimal128(precision, scale);
698+
Ok(Arc::new(
699+
Decimal128Array::from(values).with_data_type(decimal_type),
700+
))
668701
}
669702

670703
/// Convert cells to list array for array types
@@ -713,7 +746,7 @@ impl TableRowEncoder {
713746
/// Convert a Postgres type to Delta DataType using delta-kernel's conversion traits
714747
#[allow(dead_code)]
715748
pub(crate) fn postgres_type_to_delta(typ: &PGType) -> Result<DeltaDataType, ArrowError> {
716-
let arrow_type = TableRowEncoder::postgres_type_to_arrow_type(typ);
749+
let arrow_type = TableRowEncoder::postgres_type_to_arrow_type(typ, -1);
717750
DeltaDataType::try_from_arrow(&arrow_type)
718751
}
719752

@@ -723,7 +756,7 @@ pub(crate) fn postgres_to_delta_schema(schema: &PGTableSchema) -> DeltaResult<De
723756
.column_schemas
724757
.iter()
725758
.map(|col| {
726-
let arrow_type = TableRowEncoder::postgres_type_to_arrow_type(&col.typ);
759+
let arrow_type = TableRowEncoder::postgres_type_to_arrow_type(&col.typ, col.modifier);
727760
let delta_data_type = DeltaDataType::try_from_arrow(&arrow_type)
728761
.map_err(|e| deltalake::DeltaTableError::Generic(e.to_string()))?;
729762
Ok(DeltaStructField::new(
@@ -740,6 +773,7 @@ pub(crate) fn postgres_to_delta_schema(schema: &PGTableSchema) -> DeltaResult<De
740773
#[cfg(test)]
741774
mod tests {
742775
use super::*;
776+
use delta_kernel::schema::{DecimalType, PrimitiveType};
743777

744778
#[test]
745779
fn test_scalar_mappings() {
@@ -780,11 +814,10 @@ mod tests {
780814
postgres_type_to_delta(&PGType::BYTEA).unwrap(),
781815
DeltaDataType::BINARY
782816
));
783-
// Test NUMERIC mapping - delta-kernel should handle the conversion
784-
let numeric_result = postgres_type_to_delta(&PGType::NUMERIC).unwrap();
785-
// The actual result depends on delta-kernel's conversion implementation
786-
// For now, just verify the conversion succeeds
787-
println!("NUMERIC maps to: {:?}", numeric_result);
817+
assert!(matches!(
818+
postgres_type_to_delta(&PGType::NUMERIC).unwrap(),
819+
DeltaDataType::Primitive(PrimitiveType::Decimal(DecimalType { .. }))
820+
));
788821
}
789822

790823
#[test]
@@ -881,7 +914,7 @@ mod tests {
881914
);
882915

883916
// Test that we can convert back to Arrow
884-
let arrow_type = TableRowEncoder::postgres_type_to_arrow_type(&pg_type);
917+
let arrow_type = TableRowEncoder::postgres_type_to_arrow_type(&pg_type, -1);
885918
let roundtrip_delta = DeltaDataType::try_from_arrow(&arrow_type);
886919
assert!(
887920
roundtrip_delta.is_ok(),
@@ -920,45 +953,79 @@ mod tests {
920953
assert_eq!(batch.num_columns(), 12); // All test columns
921954
}
922955

956+
#[test]
957+
fn test_decimal_precision_scale_extraction() {
958+
// Test specific atttypmod values from the Stack Overflow example
959+
// https://stackoverflow.com/questions/72725508/how-to-calculate-numeric-precision-and-other-vals-from-atttypmod
960+
961+
// NUMERIC(5,2) -> atttypmod = 327686
962+
assert_eq!(extract_numeric_precision(327686), 5);
963+
assert_eq!(extract_numeric_scale(327686), 2);
964+
965+
// NUMERIC(5,1) -> atttypmod = 327685
966+
assert_eq!(extract_numeric_precision(327685), 5);
967+
assert_eq!(extract_numeric_scale(327685), 1);
968+
969+
// NUMERIC(6,3) -> atttypmod = 393223
970+
assert_eq!(extract_numeric_precision(393223), 6);
971+
assert_eq!(extract_numeric_scale(393223), 3);
972+
973+
// NUMERIC(4,4) -> atttypmod = 262152
974+
assert_eq!(extract_numeric_precision(262152), 4);
975+
assert_eq!(extract_numeric_scale(262152), 4);
976+
977+
// Test -1 (no limit)
978+
assert_eq!(extract_numeric_precision(-1), 38); // Max precision
979+
assert_eq!(extract_numeric_scale(-1), 18); // Default scale
980+
981+
let arrow_type = TableRowEncoder::postgres_type_to_arrow_type(&PGType::NUMERIC, 327686);
982+
if let ArrowDataType::Decimal128(precision, scale) = arrow_type {
983+
assert_eq!(precision, 5);
984+
assert_eq!(scale, 2);
985+
} else {
986+
panic!("Expected Decimal128 type, got: {:?}", arrow_type);
987+
}
988+
}
989+
923990
#[test]
924991
fn test_postgres_type_to_arrow_type_mapping() {
925992
// Test basic types
926993
assert_eq!(
927-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::BOOL),
994+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::BOOL, -1),
928995
ArrowDataType::Boolean
929996
);
930997
assert_eq!(
931-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::INT4),
998+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::INT4, -1),
932999
ArrowDataType::Int32
9331000
);
9341001
assert_eq!(
935-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::INT8),
1002+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::INT8, -1),
9361003
ArrowDataType::Int64
9371004
);
9381005
assert_eq!(
939-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::FLOAT8),
1006+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::FLOAT8, -1),
9401007
ArrowDataType::Float64
9411008
);
9421009
assert_eq!(
943-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::TEXT),
1010+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::TEXT, -1),
9441011
ArrowDataType::Utf8
9451012
);
9461013
assert_eq!(
947-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::DATE),
1014+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::DATE, -1),
9481015
ArrowDataType::Date32
9491016
);
9501017
assert_eq!(
951-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::TIME),
1018+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::TIME, -1),
9521019
ArrowDataType::Timestamp(TimeUnit::Microsecond, None)
9531020
);
9541021
assert_eq!(
955-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::BYTEA),
1022+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::BYTEA, -1),
9561023
ArrowDataType::Binary
9571024
);
9581025

9591026
// Test array types
9601027
if let ArrowDataType::List(field) =
961-
TableRowEncoder::postgres_type_to_arrow_type(&PGType::INT4_ARRAY)
1028+
TableRowEncoder::postgres_type_to_arrow_type(&PGType::INT4_ARRAY, -1)
9621029
{
9631030
assert_eq!(*field.data_type(), ArrowDataType::Int32);
9641031
} else {

0 commit comments

Comments
 (0)