diff --git a/connectorx/src/destinations/arrow/arrow_assoc.rs b/connectorx/src/destinations/arrow/arrow_assoc.rs index e71a61aa1..731e63991 100644 --- a/connectorx/src/destinations/arrow/arrow_assoc.rs +++ b/connectorx/src/destinations/arrow/arrow_assoc.rs @@ -2,10 +2,10 @@ use super::{ errors::{ArrowDestinationError, Result}, typesystem::{DateTimeWrapperMicro, NaiveDateTimeWrapperMicro, NaiveTimeWrapperMicro}, }; -use crate::constants::SECONDS_IN_DAY; +use crate::{constants::SECONDS_IN_DAY, utils::decimal_to_i128}; use arrow::array::{ - ArrayBuilder, BooleanBuilder, Date32Builder, Float32Builder, Float64Builder, Int16Builder, - Int32Builder, Int64Builder, LargeBinaryBuilder, LargeListBuilder, StringBuilder, + ArrayBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, + Int16Builder, Int32Builder, Int64Builder, LargeBinaryBuilder, LargeListBuilder, StringBuilder, Time64MicrosecondBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampNanosecondBuilder, UInt16Builder, UInt32Builder, UInt64Builder, }; @@ -13,6 +13,7 @@ use arrow::datatypes::Field; use arrow::datatypes::{DataType as ArrowDataType, TimeUnit}; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use fehler::throws; +use rust_decimal::Decimal; /// Associate arrow builder with native type pub trait ArrowAssoc { @@ -71,6 +72,51 @@ impl_arrow_assoc!(f32, ArrowDataType::Float32, Float32Builder); impl_arrow_assoc!(f64, ArrowDataType::Float64, Float64Builder); impl_arrow_assoc!(bool, ArrowDataType::Boolean, BooleanBuilder); +const DEFAULT_ARROW_DECIMAL_PRECISION: u8 = 38; +const DEFAULT_ARROW_DECIMAL_SCALE: i8 = 10; +const DEFAULT_ARROW_DECIMAL: ArrowDataType = + ArrowDataType::Decimal128(DEFAULT_ARROW_DECIMAL_PRECISION, DEFAULT_ARROW_DECIMAL_SCALE); + +impl ArrowAssoc for Decimal { + type Builder = Decimal128Builder; + + fn builder(nrows: usize) -> Self::Builder { + Decimal128Builder::with_capacity(nrows).with_data_type(DEFAULT_ARROW_DECIMAL) + } + + fn append(builder: &mut Self::Builder, value: Self) -> Result<()> { + builder.append_value(decimal_to_i128(value, DEFAULT_ARROW_DECIMAL_SCALE as u32)?); + Ok(()) + } + + fn field(header: &str) -> Field { + Field::new(header, DEFAULT_ARROW_DECIMAL, false) + } +} + +impl ArrowAssoc for Option { + type Builder = Decimal128Builder; + + fn builder(nrows: usize) -> Self::Builder { + Decimal128Builder::with_capacity(nrows).with_data_type(DEFAULT_ARROW_DECIMAL) + } + + fn append(builder: &mut Self::Builder, value: Self) -> Result<()> { + match value { + Some(v) => builder.append_option(Some(decimal_to_i128( + v, + DEFAULT_ARROW_DECIMAL_SCALE as u32, + )?)), + None => builder.append_null(), + } + Ok(()) + } + + fn field(header: &str) -> Field { + Field::new(header, DEFAULT_ARROW_DECIMAL, true) + } +} + impl ArrowAssoc for &str { type Builder = StringBuilder; @@ -486,6 +532,93 @@ impl ArrowAssoc for Vec { } } +impl ArrowAssoc for Option>> { + type Builder = LargeListBuilder; + + fn builder(nrows: usize) -> Self::Builder { + LargeListBuilder::with_capacity( + Decimal128Builder::with_capacity(nrows).with_data_type(DEFAULT_ARROW_DECIMAL), + nrows, + ) + } + + fn append(builder: &mut Self::Builder, value: Self) -> Result<()> { + match value { + Some(vals) => { + let mut list = vec![]; + + for val in vals { + match val { + Some(v) => { + list.push(Some(decimal_to_i128( + v, + DEFAULT_ARROW_DECIMAL_SCALE as u32, + )?)); + } + None => list.push(None), + } + } + + builder.append_value(list); + } + None => builder.append_null(), + }; + Ok(()) + } + + fn field(header: &str) -> Field { + Field::new( + header, + ArrowDataType::LargeList(std::sync::Arc::new(Field::new_list_field( + DEFAULT_ARROW_DECIMAL, + true, + ))), + true, + ) + } +} + +impl ArrowAssoc for Vec> { + type Builder = LargeListBuilder; + + fn builder(nrows: usize) -> Self::Builder { + LargeListBuilder::with_capacity( + Decimal128Builder::with_capacity(nrows).with_data_type(DEFAULT_ARROW_DECIMAL), + nrows, + ) + } + + fn append(builder: &mut Self::Builder, vals: Self) -> Result<()> { + let mut list = vec![]; + + for val in vals { + match val { + Some(v) => { + list.push(Some(decimal_to_i128( + v, + DEFAULT_ARROW_DECIMAL_SCALE as u32, + )?)); + } + None => list.push(None), + } + } + + builder.append_value(list); + Ok(()) + } + + fn field(header: &str) -> Field { + Field::new( + header, + ArrowDataType::LargeList(std::sync::Arc::new(Field::new_list_field( + DEFAULT_ARROW_DECIMAL, + false, + ))), + false, + ) + } +} + macro_rules! impl_arrow_array_assoc { ($T:ty, $AT:expr, $B:ident) => { impl ArrowAssoc for $T { diff --git a/connectorx/src/destinations/arrow/typesystem.rs b/connectorx/src/destinations/arrow/typesystem.rs index 3eae46bb0..2b190c5c6 100644 --- a/connectorx/src/destinations/arrow/typesystem.rs +++ b/connectorx/src/destinations/arrow/typesystem.rs @@ -1,5 +1,6 @@ use crate::impl_typesystem; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use rust_decimal::Decimal; #[derive(Debug, Clone, Copy)] pub struct DateTimeWrapperMicro(pub DateTime); @@ -20,6 +21,7 @@ pub enum ArrowTypeSystem { UInt64(bool), Float32(bool), Float64(bool), + Decimal(bool), Boolean(bool), LargeUtf8(bool), LargeBinary(bool), @@ -40,6 +42,7 @@ pub enum ArrowTypeSystem { UInt64Array(bool), Float32Array(bool), Float64Array(bool), + DecimalArray(bool), } impl_typesystem! { @@ -53,6 +56,7 @@ impl_typesystem! { { UInt64 => u64 } { Float32 => f32 } { Float64 => f64 } + { Decimal => Decimal } { Boolean => bool } { LargeUtf8 => String } { LargeBinary => Vec } @@ -73,5 +77,6 @@ impl_typesystem! { { UInt64Array => Vec> } { Float32Array => Vec> } { Float64Array => Vec> } + { DecimalArray => Vec> } } } diff --git a/connectorx/src/destinations/arrowstream/arrow_assoc.rs b/connectorx/src/destinations/arrowstream/arrow_assoc.rs index be119a7f5..cc4409fdb 100644 --- a/connectorx/src/destinations/arrowstream/arrow_assoc.rs +++ b/connectorx/src/destinations/arrowstream/arrow_assoc.rs @@ -1,15 +1,21 @@ use super::errors::{ArrowDestinationError, Result}; use crate::constants::SECONDS_IN_DAY; +use crate::utils::decimal_to_i128; use arrow::array::{ - ArrayBuilder, BooleanBuilder, Date32Builder, Date64Builder, Float32Builder, Float64Builder, - Int32Builder, Int64Builder, LargeBinaryBuilder, StringBuilder, Time64NanosecondBuilder, - TimestampNanosecondBuilder, UInt32Builder, UInt64Builder, + ArrayBuilder, BooleanBuilder, Date32Builder, Date64Builder, Decimal128Builder, Float32Builder, + Float64Builder, Int32Builder, Int64Builder, LargeBinaryBuilder, StringBuilder, + Time64NanosecondBuilder, TimestampNanosecondBuilder, UInt32Builder, UInt64Builder, }; use arrow::datatypes::Field; use arrow::datatypes::{DataType as ArrowDataType, TimeUnit}; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use fehler::throws; +use rust_decimal::Decimal; +const DEFAULT_ARROW_DECIMAL_PRECISION: u8 = 38; +const DEFAULT_ARROW_DECIMAL_SCALE: i8 = 10; +const DEFAULT_ARROW_DECIMAL: ArrowDataType = + ArrowDataType::Decimal128(DEFAULT_ARROW_DECIMAL_PRECISION, DEFAULT_ARROW_DECIMAL_SCALE); /// Associate arrow builder with native type pub trait ArrowAssoc { type Builder: ArrayBuilder + Send; @@ -65,6 +71,46 @@ impl_arrow_assoc!(f32, ArrowDataType::Float32, Float32Builder); impl_arrow_assoc!(f64, ArrowDataType::Float64, Float64Builder); impl_arrow_assoc!(bool, ArrowDataType::Boolean, BooleanBuilder); +impl ArrowAssoc for Decimal { + type Builder = Decimal128Builder; + + fn builder(nrows: usize) -> Self::Builder { + Decimal128Builder::with_capacity(nrows).with_data_type(DEFAULT_ARROW_DECIMAL) + } + + fn append(builder: &mut Self::Builder, value: Self) -> Result<()> { + builder.append_value(decimal_to_i128(value, DEFAULT_ARROW_DECIMAL_SCALE as u32)?); + Ok(()) + } + + fn field(header: &str) -> Field { + Field::new(header, DEFAULT_ARROW_DECIMAL, false) + } +} + +impl ArrowAssoc for Option { + type Builder = Decimal128Builder; + + fn builder(nrows: usize) -> Self::Builder { + Decimal128Builder::with_capacity(nrows).with_data_type(DEFAULT_ARROW_DECIMAL) + } + + fn append(builder: &mut Self::Builder, value: Self) -> Result<()> { + match value { + Some(v) => builder.append_option(Some(decimal_to_i128( + v, + DEFAULT_ARROW_DECIMAL_SCALE as u32, + )?)), + None => builder.append_null(), + } + Ok(()) + } + + fn field(header: &str) -> Field { + Field::new(header, DEFAULT_ARROW_DECIMAL, true) + } +} + impl ArrowAssoc for &str { type Builder = StringBuilder; diff --git a/connectorx/src/destinations/arrowstream/typesystem.rs b/connectorx/src/destinations/arrowstream/typesystem.rs index a6997a2ba..61e3d6d7a 100644 --- a/connectorx/src/destinations/arrowstream/typesystem.rs +++ b/connectorx/src/destinations/arrowstream/typesystem.rs @@ -1,5 +1,6 @@ use crate::impl_typesystem; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use rust_decimal::Decimal; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum ArrowTypeSystem { @@ -9,6 +10,7 @@ pub enum ArrowTypeSystem { UInt64(bool), Float32(bool), Float64(bool), + Decimal(bool), Boolean(bool), LargeUtf8(bool), LargeBinary(bool), @@ -27,6 +29,7 @@ impl_typesystem! { { UInt64 => u64 } { Float64 => f64 } { Float32 => f32 } + { Decimal => Decimal } { Boolean => bool } { LargeUtf8 => String } { LargeBinary => Vec } diff --git a/connectorx/src/transports/postgres_arrow.rs b/connectorx/src/transports/postgres_arrow.rs index 558fa7d81..500e1432c 100644 --- a/connectorx/src/transports/postgres_arrow.rs +++ b/connectorx/src/transports/postgres_arrow.rs @@ -46,7 +46,7 @@ macro_rules! impl_postgres_transport { mappings = { { Float4[f32] => Float32[f32] | conversion auto } { Float8[f64] => Float64[f64] | conversion auto } - { Numeric[Decimal] => Float64[f64] | conversion option } + { Numeric[Decimal] => Decimal[Decimal] | conversion auto } { Int2[i16] => Int16[i16] | conversion auto } { Int4[i32] => Int32[i32] | conversion auto } { Int8[i64] => Int64[i64] | conversion auto } @@ -73,7 +73,7 @@ macro_rules! impl_postgres_transport { { Int8Array[Vec>] => Int64Array[Vec>] | conversion auto } { Float4Array[Vec>] => Float32Array[Vec>] | conversion auto } { Float8Array[Vec>] => Float64Array[Vec>] | conversion auto } - { NumericArray[Vec>] => Float64Array[Vec>] | conversion option } + { NumericArray[Vec>] => DecimalArray[Vec>] | conversion auto } } ); } @@ -125,17 +125,4 @@ impl TypeConversion for PostgresArrowTransport { fn convert(val: Value) -> String { val.to_string() } -} - -impl TypeConversion>, Vec>> for PostgresArrowTransport { - fn convert(val: Vec>) -> Vec> { - val.into_iter() - .map(|v| { - v.map(|v| { - v.to_f64() - .unwrap_or_else(|| panic!("cannot convert decimal {:?} to float64", v)) - }) - }) - .collect() - } -} +} \ No newline at end of file diff --git a/connectorx/src/transports/postgres_arrowstream.rs b/connectorx/src/transports/postgres_arrowstream.rs index 7d1c20d9e..5fdc56606 100644 --- a/connectorx/src/transports/postgres_arrowstream.rs +++ b/connectorx/src/transports/postgres_arrowstream.rs @@ -9,7 +9,6 @@ use crate::sources::postgres::{ }; use crate::typesystem::TypeConversion; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; -use num_traits::ToPrimitive; use postgres::NoTls; use postgres_openssl::MakeTlsConnector; use rust_decimal::Decimal; @@ -43,7 +42,7 @@ macro_rules! impl_postgres_transport { mappings = { { Float4[f32] => Float64[f64] | conversion auto } { Float8[f64] => Float64[f64] | conversion auto } - { Numeric[Decimal] => Float64[f64] | conversion option } + { Numeric[Decimal] => Decimal[Decimal] | conversion option } { Int2[i16] => Int64[i64] | conversion auto } { Int4[i32] => Int64[i64] | conversion auto } { Int8[i64] => Int64[i64] | conversion auto } @@ -81,10 +80,9 @@ impl TypeConversion for PostgresArrowTransport { } } -impl TypeConversion for PostgresArrowTransport { - fn convert(val: Decimal) -> f64 { - val.to_f64() - .unwrap_or_else(|| panic!("cannot convert decimal {:?} to float64", val)) +impl TypeConversion for PostgresArrowTransport { + fn convert(val: Decimal) -> Decimal { + val } } diff --git a/connectorx/src/utils.rs b/connectorx/src/utils.rs index ec3919ebb..03a7c883d 100644 --- a/connectorx/src/utils.rs +++ b/connectorx/src/utils.rs @@ -1,3 +1,5 @@ +use anyhow::Result; +use rust_decimal::Decimal; use std::ops::{Deref, DerefMut}; pub struct DummyBox(pub T); @@ -15,3 +17,18 @@ impl DerefMut for DummyBox { &mut self.0 } } + +pub fn decimal_to_i128(mut v: Decimal, scale: u32) -> Result { + v.rescale(scale); + + let v_scale = v.scale(); + if v_scale != scale as u32 { + return Err(anyhow::anyhow!( + "decimal scale is not equal to expected scale, got: {} expected: {}", + v_scale, + scale + )); + } + + Ok(v.mantissa()) +} diff --git a/connectorx/tests/test_postgres.rs b/connectorx/tests/test_postgres.rs index dca520a45..cee66ab60 100644 --- a/connectorx/tests/test_postgres.rs +++ b/connectorx/tests/test_postgres.rs @@ -1,8 +1,9 @@ use arrow::{ array::{ - BooleanArray, BooleanBuilder, Date32Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, LargeBinaryArray, LargeListArray, LargeListBuilder, StringArray, - StringBuilder, Time64MicrosecondArray, TimestampMicrosecondArray, + Array, BooleanArray, BooleanBuilder, Date32Array, Decimal128Array, Decimal128Builder, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, LargeBinaryArray, + LargeListArray, LargeListBuilder, StringArray, StringBuilder, Time64MicrosecondArray, + TimestampMicrosecondArray, }, datatypes::{Float32Type, Float64Type, Int16Type, Int32Type, Int64Type}, record_batch::RecordBatch, @@ -25,6 +26,46 @@ use postgres::NoTls; use std::env; use url::Url; +#[test] +fn test_types_simple_postgres_aa() { + let _ = env_logger::builder().is_test(true).try_init(); + + let dburl = env::var("POSTGRES_URL").unwrap(); + + let vars = vec!["test_numeric"].join(","); + + let queries = [CXQuery::naked(format!("select {vars} from test_types"))]; + let url = Url::parse(dburl.as_str()).unwrap(); + let (config, _tls) = rewrite_tls_args(&url).unwrap(); + let builder = PostgresSource::::new(config, NoTls, 2).unwrap(); + let mut destination = ArrowDestination::new(); + let dispatcher = Dispatcher::<_, _, PostgresArrowTransport>::new( + builder, + &mut destination, + &queries, + Some(String::from("select * from test_types")), + ); + + dispatcher.run().expect("run dispatcher"); + + let result = destination.arrow().unwrap(); + + arrow::util::pretty::print_batches(&result).unwrap(); +} + +#[test] +fn test_decimal_128() { + let mut builder = Decimal128Builder::new(); + builder.append_value(1); + builder.append_value(1234567890); + builder.append_value(1234567890); + builder.append_value(1234567890); + builder.append_value(1234567890); + let decimal = builder.finish(); + + println!("decimal: {:#?}", decimal.value_as_string(0)); +} + #[test] fn load_and_parse() { let _ = env_logger::builder().is_test(true).try_init(); @@ -763,18 +804,19 @@ pub fn verify_arrow_type_results(result: Vec, protocol: &str) { // test_numeric col += 1; - assert!(result[0] + let actual = result[0] .column(col) .as_any() - .downcast_ref::() - .unwrap() - .eq(&Float64Array::from(vec![ - Some(0.01), - Some(521.34), - Some(0.0), - Some(-112.3), - None, - ]))); + .downcast_ref::() + .unwrap(); + let expected = build_decimal_array(vec![ + Some(100000000), + Some(5213400000000), + Some(0), + Some(-1123000000000), + None, + ]); + assert_eq!(actual, &expected); // test_bpchar col += 1; @@ -970,20 +1012,28 @@ pub fn verify_arrow_type_results(result: Vec, protocol: &str) { // test_narray col += 1; - assert!(result[0] + let actual = result[0] .column(col) .as_any() .downcast_ref::() - .unwrap() - .eq(&LargeListArray::from_iter_primitive::( - vec![ - Some(vec![Some(0.01), Some(521.23)]), - Some(vec![Some(0.12), Some(333.33), Some(22.22)]), - Some(vec![]), - Some(vec![Some(0.0), None, Some(-112.1)]), - None, - ] - ))); + .unwrap(); + + let mut expected = LargeListBuilder::new( + Decimal128Builder::new() + .with_precision_and_scale(38, 10) + .unwrap(), + ); + expected.append_value(vec![Some(100000000), Some(5212300000000)]); + expected.append_value(vec![ + Some(1200000000), + Some(3333300000000), + Some(222200000000), + ]); + expected.append_value(vec![]); + expected.append_value(vec![Some(0), None, Some(-1121000000000)]); + expected.append_null(); + + assert!(actual.eq(&expected.finish())); // test_boolarray (from_iter_primitive not available for boolean) col += 1; @@ -1313,3 +1363,18 @@ fn test_postgres_partitioned_pre_execution_queries() { Some(&2252) ); } + +fn build_decimal_array(vals: Vec>) -> Decimal128Array { + let mut builder = Decimal128Builder::new() + .with_precision_and_scale(38, 10) + .unwrap(); + + for val in vals { + match val { + Some(v) => builder.append_value(v), + None => builder.append_null(), + } + } + + builder.finish() +}