Skip to content

Commit 77311a5

Browse files
authored
support Decimal256 type in datafusion-proto (#11606)
1 parent deef834 commit 77311a5

File tree

7 files changed

+164
-5
lines changed

7 files changed

+164
-5
lines changed

datafusion/proto-common/proto/datafusion_common.proto

+7
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,12 @@ message Decimal{
130130
int32 scale = 4;
131131
}
132132

133+
// Type parameters (precision and scale) of a 256-bit decimal Arrow type.
message Decimal256Type{
134+
// Field numbers 1 and 2 are reserved and must not be assigned in this message.
// NOTE(review): presumably keeps precision/scale on the same field numbers as
// the `Decimal` message — confirm.
reserved 1, 2;
135+
// Decimal precision.
uint32 precision = 3;
136+
// Decimal scale; signed.
int32 scale = 4;
137+
}
138+
133139
message List{
134140
Field field_type = 1;
135141
}
@@ -335,6 +341,7 @@ message ArrowType{
335341
TimeUnit TIME64 = 22 ;
336342
IntervalUnit INTERVAL = 23 ;
337343
Decimal DECIMAL = 24 ;
344+
Decimal256Type DECIMAL256 = 36;
338345
List LIST = 25;
339346
List LARGE_LIST = 26;
340347
FixedSizeList FIXED_SIZE_LIST = 27;

datafusion/proto-common/src/from_proto/mod.rs

+4
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,10 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType {
260260
precision,
261261
scale,
262262
}) => DataType::Decimal128(*precision as u8, *scale as i8),
263+
arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type {
264+
precision,
265+
scale,
266+
}) => DataType::Decimal256(*precision as u8, *scale as i8),
263267
arrow_type::ArrowTypeEnum::List(list) => {
264268
let list_type =
265269
list.as_ref().field_type.as_deref().required("field_type")?;

datafusion/proto-common/src/generated/pbjson.rs

+125
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,9 @@ impl serde::Serialize for ArrowType {
175175
arrow_type::ArrowTypeEnum::Decimal(v) => {
176176
struct_ser.serialize_field("DECIMAL", v)?;
177177
}
178+
arrow_type::ArrowTypeEnum::Decimal256(v) => {
179+
struct_ser.serialize_field("DECIMAL256", v)?;
180+
}
178181
arrow_type::ArrowTypeEnum::List(v) => {
179182
struct_ser.serialize_field("LIST", v)?;
180183
}
@@ -241,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
241244
"TIME64",
242245
"INTERVAL",
243246
"DECIMAL",
247+
"DECIMAL256",
244248
"LIST",
245249
"LARGE_LIST",
246250
"LARGELIST",
@@ -282,6 +286,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
282286
Time64,
283287
Interval,
284288
Decimal,
289+
Decimal256,
285290
List,
286291
LargeList,
287292
FixedSizeList,
@@ -338,6 +343,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
338343
"TIME64" => Ok(GeneratedField::Time64),
339344
"INTERVAL" => Ok(GeneratedField::Interval),
340345
"DECIMAL" => Ok(GeneratedField::Decimal),
346+
"DECIMAL256" => Ok(GeneratedField::Decimal256),
341347
"LIST" => Ok(GeneratedField::List),
342348
"LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList),
343349
"FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList),
@@ -556,6 +562,13 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
556562
return Err(serde::de::Error::duplicate_field("DECIMAL"));
557563
}
558564
arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal)
565+
;
566+
}
567+
GeneratedField::Decimal256 => {
568+
if arrow_type_enum__.is_some() {
569+
return Err(serde::de::Error::duplicate_field("DECIMAL256"));
570+
}
571+
arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256)
559572
;
560573
}
561574
GeneratedField::List => {
@@ -2849,6 +2862,118 @@ impl<'de> serde::Deserialize<'de> for Decimal256 {
28492862
deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor)
28502863
}
28512864
}
2865+
// Serde serialization for `Decimal256Type`.
//
// The value is written as a struct named "datafusion_common.Decimal256Type".
// A field is emitted only when it is non-zero, so a value with both fields
// equal to 0 serializes as a struct with no fields.
impl serde::Serialize for Decimal256Type {
2866+
#[allow(deprecated)]
2867+
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
2868+
where
2869+
S: serde::Serializer,
2870+
{
2871+
use serde::ser::SerializeStruct;
2872+
// First pass: count the fields that will actually be written, because
// `serialize_struct` takes the field count up front.
let mut len = 0;
2873+
if self.precision != 0 {
2874+
len += 1;
2875+
}
2876+
if self.scale != 0 {
2877+
len += 1;
2878+
}
2879+
let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?;
2880+
// Second pass: write exactly the fields counted above (same conditions).
if self.precision != 0 {
2881+
struct_ser.serialize_field("precision", &self.precision)?;
2882+
}
2883+
if self.scale != 0 {
2884+
struct_ser.serialize_field("scale", &self.scale)?;
2885+
}
2886+
struct_ser.end()
2887+
}
2888+
}
2889+
// Serde deserialization for `Decimal256Type`.
//
// Accepts a map with the keys "precision" and "scale". Any other key is
// rejected with an `unknown_field` error and a repeated key with a
// `duplicate_field` error. A key that is absent leaves the corresponding
// struct field at its type's default value.
impl<'de> serde::Deserialize<'de> for Decimal256Type {
2890+
#[allow(deprecated)]
2891+
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
2892+
where
2893+
D: serde::Deserializer<'de>,
2894+
{
2895+
// Names of the accepted fields; also used in error messages below.
const FIELDS: &[&str] = &[
2896+
"precision",
2897+
"scale",
2898+
];
2899+
2900+
// One variant per accepted field name; map keys are decoded into this enum.
#[allow(clippy::enum_variant_names)]
2901+
enum GeneratedField {
2902+
Precision,
2903+
Scale,
2904+
}
2905+
// Decodes a map key (a string identifier) into a `GeneratedField`.
impl<'de> serde::Deserialize<'de> for GeneratedField {
2906+
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
2907+
where
2908+
D: serde::Deserializer<'de>,
2909+
{
2910+
// Local visitor for field names; distinct from the struct-level
// `GeneratedVisitor` declared further down in the outer function.
struct GeneratedVisitor;
2911+
2912+
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
2913+
type Value = GeneratedField;
2914+
2915+
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2916+
write!(formatter, "expected one of: {:?}", &FIELDS)
2917+
}
2918+
2919+
#[allow(unused_variables)]
2920+
fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
2921+
where
2922+
E: serde::de::Error,
2923+
{
2924+
// Exact, case-sensitive match on the field name.
match value {
2925+
"precision" => Ok(GeneratedField::Precision),
2926+
"scale" => Ok(GeneratedField::Scale),
2927+
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
2928+
}
2929+
}
2930+
}
2931+
deserializer.deserialize_identifier(GeneratedVisitor)
2932+
}
2933+
}
2934+
// Visitor that assembles a `Decimal256Type` from a map of fields.
struct GeneratedVisitor;
2935+
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
2936+
type Value = Decimal256Type;
2937+
2938+
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2939+
formatter.write_str("struct datafusion_common.Decimal256Type")
2940+
}
2941+
2942+
fn visit_map<V>(self, mut map_: V) -> std::result::Result<Decimal256Type, V::Error>
2943+
where
2944+
V: serde::de::MapAccess<'de>,
2945+
{
2946+
// `None` means "not seen yet"; used for duplicate detection and
// for falling back to the default value at the end.
let mut precision__ = None;
2947+
let mut scale__ = None;
2948+
while let Some(k) = map_.next_key()? {
2949+
match k {
2950+
GeneratedField::Precision => {
2951+
if precision__.is_some() {
2952+
return Err(serde::de::Error::duplicate_field("precision"));
2953+
}
2954+
// The value is read through pbjson's `NumberDeserialize`
// wrapper; `.0` takes the wrapped number out.
precision__ =
2955+
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
2956+
;
2957+
}
2958+
GeneratedField::Scale => {
2959+
if scale__.is_some() {
2960+
return Err(serde::de::Error::duplicate_field("scale"));
2961+
}
2962+
// Same decoding path as `precision` above.
scale__ =
2963+
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
2964+
;
2965+
}
2966+
}
2967+
}
2968+
// Fields that never appeared in the map take their default value.
Ok(Decimal256Type {
2969+
precision: precision__.unwrap_or_default(),
2970+
scale: scale__.unwrap_or_default(),
2971+
})
2972+
}
2973+
}
2974+
deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor)
2975+
}
2976+
}
28522977
impl serde::Serialize for DfField {
28532978
#[allow(deprecated)]
28542979
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>

datafusion/proto-common/src/generated/prost.rs

+11-1
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,14 @@ pub struct Decimal {
140140
}
141141
/// Protobuf message holding the precision and scale of a 256-bit decimal type.
///
/// Only tags 3 and 4 are used by this message.
#[allow(clippy::derive_partial_eq_without_eq)]
142142
#[derive(Clone, PartialEq, ::prost::Message)]
143+
pub struct Decimal256Type {
144+
/// Decimal precision; encoded as a protobuf `uint32` with tag 3.
#[prost(uint32, tag = "3")]
145+
pub precision: u32,
146+
/// Decimal scale; encoded as a protobuf `int32` with tag 4.
#[prost(int32, tag = "4")]
147+
pub scale: i32,
148+
}
149+
#[allow(clippy::derive_partial_eq_without_eq)]
150+
#[derive(Clone, PartialEq, ::prost::Message)]
143151
pub struct List {
144152
#[prost(message, optional, boxed, tag = "1")]
145153
pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
@@ -446,7 +454,7 @@ pub struct Decimal256 {
446454
pub struct ArrowType {
447455
#[prost(
448456
oneof = "arrow_type::ArrowTypeEnum",
449-
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
457+
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
450458
)]
451459
pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
452460
}
@@ -516,6 +524,8 @@ pub mod arrow_type {
516524
Interval(i32),
517525
#[prost(message, tag = "24")]
518526
Decimal(super::Decimal),
527+
#[prost(message, tag = "36")]
528+
Decimal256(super::Decimal256Type),
519529
#[prost(message, tag = "25")]
520530
List(::prost::alloc::boxed::Box<super::List>),
521531
#[prost(message, tag = "26")]

datafusion/proto-common/src/to_proto/mod.rs

+4-3
Original file line number | Diff line number | Diff line change
@@ -191,9 +191,10 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum {
191191
precision: *precision as u32,
192192
scale: *scale as i32,
193193
}),
194-
DataType::Decimal256(_, _) => {
195-
return Err(Error::General("Proto serialization error: The Decimal256 data type is not yet supported".to_owned()))
196-
}
194+
DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type {
195+
precision: *precision as u32,
196+
scale: *scale as i32,
197+
}),
197198
DataType::Map(field, sorted) => {
198199
Self::Map(Box::new(
199200
protobuf::Map {

datafusion/proto/src/generated/datafusion_proto_common.rs

+11-1
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,14 @@ pub struct Decimal {
140140
}
141141
/// Protobuf message holding the precision and scale of a 256-bit decimal type.
///
/// Only tags 3 and 4 are used by this message.
#[allow(clippy::derive_partial_eq_without_eq)]
142142
#[derive(Clone, PartialEq, ::prost::Message)]
143+
pub struct Decimal256Type {
144+
/// Decimal precision; encoded as a protobuf `uint32` with tag 3.
#[prost(uint32, tag = "3")]
145+
pub precision: u32,
146+
/// Decimal scale; encoded as a protobuf `int32` with tag 4.
#[prost(int32, tag = "4")]
147+
pub scale: i32,
148+
}
149+
#[allow(clippy::derive_partial_eq_without_eq)]
150+
#[derive(Clone, PartialEq, ::prost::Message)]
143151
pub struct List {
144152
#[prost(message, optional, boxed, tag = "1")]
145153
pub field_type: ::core::option::Option<::prost::alloc::boxed::Box<Field>>,
@@ -446,7 +454,7 @@ pub struct Decimal256 {
446454
pub struct ArrowType {
447455
#[prost(
448456
oneof = "arrow_type::ArrowTypeEnum",
449-
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33"
457+
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33"
450458
)]
451459
pub arrow_type_enum: ::core::option::Option<arrow_type::ArrowTypeEnum>,
452460
}
@@ -516,6 +524,8 @@ pub mod arrow_type {
516524
Interval(i32),
517525
#[prost(message, tag = "24")]
518526
Decimal(super::Decimal),
527+
#[prost(message, tag = "36")]
528+
Decimal256(super::Decimal256Type),
519529
#[prost(message, tag = "25")]
520530
List(::prost::alloc::boxed::Box<super::List>),
521531
#[prost(message, tag = "26")]

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

+2
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ use arrow::array::{
2727
use arrow::datatypes::{
2828
DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
2929
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
30+
DECIMAL256_MAX_PRECISION,
3031
};
3132
use prost::Message;
3233

@@ -1379,6 +1380,7 @@ fn round_trip_datatype() {
13791380
DataType::Utf8,
13801381
DataType::LargeUtf8,
13811382
DataType::Decimal128(7, 12),
1383+
DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0),
13821384
// Recursive list tests
13831385
DataType::List(new_arc_field("Level1", DataType::Binary, true)),
13841386
DataType::List(new_arc_field(

0 commit comments

Comments
 (0)