Skip to content

support Decimal256 type in datafusion-proto #11606

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions datafusion/proto-common/proto/datafusion_common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ message Decimal{
int32 scale = 4;
}

message Decimal256Type{
reserved 1, 2;
uint32 precision = 3;
int32 scale = 4;
}

message List{
Field field_type = 1;
}
Expand Down Expand Up @@ -335,6 +341,7 @@ message ArrowType{
TimeUnit TIME64 = 22 ;
IntervalUnit INTERVAL = 23 ;
Decimal DECIMAL = 24 ;
Decimal256Type DECIMAL256 = 36;
List LIST = 25;
List LARGE_LIST = 26;
FixedSizeList FIXED_SIZE_LIST = 27;
Expand Down
4 changes: 4 additions & 0 deletions datafusion/proto-common/src/from_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType {
precision,
scale,
}) => DataType::Decimal128(*precision as u8, *scale as i8),
arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type {
precision,
scale,
}) => DataType::Decimal256(*precision as u8, *scale as i8),
arrow_type::ArrowTypeEnum::List(list) => {
let list_type =
list.as_ref().field_type.as_deref().required("field_type")?;
Expand Down
125 changes: 125 additions & 0 deletions datafusion/proto-common/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ impl serde::Serialize for ArrowType {
arrow_type::ArrowTypeEnum::Decimal(v) => {
struct_ser.serialize_field("DECIMAL", v)?;
}
arrow_type::ArrowTypeEnum::Decimal256(v) => {
struct_ser.serialize_field("DECIMAL256", v)?;
}
arrow_type::ArrowTypeEnum::List(v) => {
struct_ser.serialize_field("LIST", v)?;
}
Expand Down Expand Up @@ -241,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
"TIME64",
"INTERVAL",
"DECIMAL",
"DECIMAL256",
"LIST",
"LARGE_LIST",
"LARGELIST",
Expand Down Expand Up @@ -282,6 +286,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
Time64,
Interval,
Decimal,
Decimal256,
List,
LargeList,
FixedSizeList,
Expand Down Expand Up @@ -338,6 +343,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
"TIME64" => Ok(GeneratedField::Time64),
"INTERVAL" => Ok(GeneratedField::Interval),
"DECIMAL" => Ok(GeneratedField::Decimal),
"DECIMAL256" => Ok(GeneratedField::Decimal256),
"LIST" => Ok(GeneratedField::List),
"LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList),
"FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList),
Expand Down Expand Up @@ -556,6 +562,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
return Err(serde::de::Error::duplicate_field("DECIMAL"));
}
arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal)
;
}
GeneratedField::Decimal256 => {
if arrow_type_enum__.is_some() {
return Err(serde::de::Error::duplicate_field("DECIMAL256"));
}
arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256)
;
}
GeneratedField::List => {
Expand Down Expand Up @@ -2849,6 +2862,118 @@ impl<'de> serde::Deserialize<'de> for Decimal256 {
deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for Decimal256Type {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut len = 0;
if self.precision != 0 {
len += 1;
}
if self.scale != 0 {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?;
if self.precision != 0 {
struct_ser.serialize_field("precision", &self.precision)?;
}
if self.scale != 0 {
struct_ser.serialize_field("scale", &self.scale)?;
}
struct_ser.end()
}
}
impl<'de> serde::Deserialize<'de> for Decimal256Type {
#[allow(deprecated)]
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
const FIELDS: &[&str] = &[
"precision",
"scale",
];

#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Precision,
Scale,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
where
D: serde::Deserializer<'de>,
{
struct GeneratedVisitor;

impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
type Value = GeneratedField;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "expected one of: {:?}", &FIELDS)
}

#[allow(unused_variables)]
fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
where
E: serde::de::Error,
{
match value {
"precision" => Ok(GeneratedField::Precision),
"scale" => Ok(GeneratedField::Scale),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_identifier(GeneratedVisitor)
}
}
struct GeneratedVisitor;
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
type Value = Decimal256Type;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("struct datafusion_common.Decimal256Type")
}

fn visit_map<V>(self, mut map_: V) -> std::result::Result<Decimal256Type, V::Error>
where
V: serde::de::MapAccess<'de>,
{
let mut precision__ = None;
let mut scale__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Precision => {
if precision__.is_some() {
return Err(serde::de::Error::duplicate_field("precision"));
}
precision__ =
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
;
}
GeneratedField::Scale => {
if scale__.is_some() {
return Err(serde::de::Error::duplicate_field("scale"));
}
scale__ =
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
;
}
}
}
Ok(Decimal256Type {
precision: precision__.unwrap_or_default(),
scale: scale__.unwrap_or_default(),
})
}
}
deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for DfField {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
Expand Down
12 changes: 11 additions & 1 deletion datafusion/proto-common/src/generated/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ pub struct Decimal {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Decimal256Type {
#[prost(uint32, tag = "3")]
pub precision: u32,
#[prost(int32, tag = "4")]
pub scale: i32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct List {
#[prost(message, optional, boxed, tag = "1")]
pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
Expand Down Expand Up @@ -446,7 +454,7 @@ pub struct Decimal256 {
pub struct ArrowType {
#[prost(
oneof = "arrow_type::ArrowTypeEnum",
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
)]
pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
}
Expand Down Expand Up @@ -516,6 +524,8 @@ pub mod arrow_type {
Interval(i32),
#[prost(message, tag = "24")]
Decimal(super::Decimal),
#[prost(message, tag = "36")]
Decimal256(super::Decimal256Type),
#[prost(message, tag = "25")]
List(::prost::alloc::boxed::Box<super::List>),
#[prost(message, tag = "26")]
Expand Down
7 changes: 4 additions & 3 deletions datafusion/proto-common/src/to_proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum {
precision: *precision as u32,
scale: *scale as i32,
}),
DataType::Decimal256(_, _) => {
return Err(Error::General("Proto serialization error: The Decimal256 data type is not yet supported".to_owned()))
}
DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type {
precision: *precision as u32,
scale: *scale as i32,
}),
DataType::Map(field, sorted) => {
Self::Map(Box::new(
protobuf::Map {
Expand Down
12 changes: 11 additions & 1 deletion datafusion/proto/src/generated/datafusion_proto_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ pub struct Decimal {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Decimal256Type {
#[prost(uint32, tag = "3")]
pub precision: u32,
#[prost(int32, tag = "4")]
pub scale: i32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct List {
#[prost(message, optional, boxed, tag = "1")]
pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
Expand Down Expand Up @@ -446,7 +454,7 @@ pub struct Decimal256 {
pub struct ArrowType {
#[prost(
oneof = "arrow_type::ArrowTypeEnum",
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
)]
pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
}
Expand Down Expand Up @@ -516,6 +524,8 @@ pub mod arrow_type {
Interval(i32),
#[prost(message, tag = "24")]
Decimal(super::Decimal),
#[prost(message, tag = "36")]
Decimal256(super::Decimal256Type),
#[prost(message, tag = "25")]
List(::prost::alloc::boxed::Box<super::List>),
#[prost(message, tag = "26")]
Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use arrow::array::{
use arrow::datatypes::{
DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
DECIMAL256_MAX_PRECISION,
};
use prost::Message;

Expand Down Expand Up @@ -1379,6 +1380,7 @@ fn round_trip_datatype() {
DataType::Utf8,
DataType::LargeUtf8,
DataType::Decimal128(7, 12),
DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0),
// Recursive list tests
DataType::List(new_arc_field("Level1", DataType::Binary, true)),
DataType::List(new_arc_field(
Expand Down