Handle Serde for Custom ScalarUDFImpl traits #8706

Closed
alamb opened this issue Jan 1, 2024 · 13 comments · Fixed by #9395
Labels
enhancement New feature or request

Comments

@alamb
Contributor

alamb commented Jan 1, 2024

Is your feature request related to a problem or challenge?

#8578 added a ScalarUDFImpl trait for implementing ScalarUDF.

@thinkharderdev said: #8578 (comment)

Nice! It would be very useful to be able to handle serde as well for custom implementations (perhaps in a different PR?). I think this could fit relatively easily into LogicalExtensionCodec.

Describe the solution you'd like

No response

Describe alternatives you've considered

No response

Additional context

No response

@alamb alamb added the enhancement New feature or request label Jan 1, 2024
@yyy1000
Contributor

yyy1000 commented Jan 24, 2024

I'd also like to work on this. 😃

@yyy1000
Contributor

yyy1000 commented Jan 25, 2024

I need some guidance on this. 🤔
It seems ScalarUDF can already handle serde; what would this issue need to implement?
https://github.com/apache/arrow-datafusion/blob/7a0af5be2323443faa75cc5876651a72c3253af8/datafusion/proto/src/logical_plan/from_proto.rs#L1818-L1826

@alamb
Contributor Author

alamb commented Jan 25, 2024

Maybe @thinkharderdev can comment -- perhaps nothing is needed?

@yyy1000
Contributor

yyy1000 commented Jan 25, 2024

Aha, maybe I can help other issues first. 😄

@thinkharderdev
Contributor

Hey @yyy1000, I think there is some work to do here. Currently the serialization for a UDF looks like this:

                    ScalarFunctionDefinition::UDF(fun) => Self {
                        expr_type: Some(ExprType::ScalarUdfExpr(
                            protobuf::ScalarUdfExprNode {
                                fun_name: fun.name().to_string(),
                                args,
                            },
                        )),
                    },

i.e., we just use the name and assume that wherever the expression is deserialized, there will be a registry in which the scalar function definition can be looked up by name.
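To make the limitation concrete, here is a small std-only sketch. `RegexUdf`, `UdfNode`, and `Registry` are hypothetical stand-ins, not DataFusion types: because only the name is serialized, the receiving side rebuilds whatever its registry holds under that name, and any per-instance state on the sender's side is lost.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a UDF that carries per-instance state.
#[derive(Debug, Clone, PartialEq)]
struct RegexUdf {
    name: String,
    pattern: String,
}

/// Hypothetical stand-in for today's serialized node: only the name travels.
#[derive(Debug)]
struct UdfNode {
    fun_name: String,
}

/// Hypothetical registry on the receiving side, keyed by function name.
struct Registry {
    udfs: HashMap<String, RegexUdf>,
}

/// Mirrors name-only serialization: `pattern` is silently dropped.
fn serialize(udf: &RegexUdf) -> UdfNode {
    UdfNode { fun_name: udf.name.clone() }
}

/// Mirrors the name-based lookup: the receiver gets whatever it registered.
fn deserialize(node: &UdfNode, registry: &Registry) -> Option<RegexUdf> {
    registry.udfs.get(&node.fun_name).cloned()
}
```

Round-tripping a `RegexUdf` with pattern `a+b` yields the receiver's own `my_regex` instance, not the sender's pattern.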

But ideally we would be able to serialize a custom scalar function that has some sort of associated state. For example, a regex scalar function that actually contains the compiled regex in its struct definition, like:

#[derive(Debug)]
struct MyRegexUdf {
    // just assume we have some serialization of the regex state machine here
    compiled_regex: Vec<u8>,
}

impl ScalarUDFImpl for MyRegexUdf {
    // as_any, name, signature and return_type omitted for brevity

    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        // do something with compiled regex here
        todo!()
    }
}

Currently, the mechanism used for this kind of thing is a custom LogicalExtensionCodec, so I think we would add methods to that trait like:

impl LogicalExtensionCodec for DefaultLogicalExtensionCodec {
    fn try_encode_scalar_udf(
        &self,
        _node: Arc<dyn ScalarUDFImpl>,
        _buf: &mut Vec<u8>,
    ) -> Result<()> {
        not_impl_err!("LogicalExtensionCodec is not provided")
    }

    fn try_decode_scalar_udf(
        &self,
        _buf: &[u8],
        _ctx: &SessionContext,
    ) -> Result<Arc<dyn ScalarUDFImpl>> {
        not_impl_err!("LogicalExtensionCodec is not provided")
    }
}

Then I would be able to define my own UDFs that contain internal state, plus an extension codec like:

#[derive(Debug)]
struct MyLogicalExtensionCodec;

impl LogicalExtensionCodec for MyLogicalExtensionCodec {
    // ... existing required methods omitted

    fn try_encode_scalar_udf(
        &self,
        node: Arc<dyn ScalarUDFImpl>,
        buf: &mut Vec<u8>,
    ) -> Result<()> {
        // Only UDFs this codec knows about can be serialized with their state
        if let Some(regex_udf) = node.as_any().downcast_ref::<MyRegexUdf>() {
            let proto = MyRegexUdfProto {
                compiled_regex: regex_udf.compiled_regex.clone(),
            };
            // prost's Message::encode returns an EncodeError; wrap it in a DataFusionError
            proto
                .encode(buf)
                .map_err(|e| DataFusionError::External(Box::new(e)))?;
            Ok(())
        } else {
            not_impl_err!("LogicalExtensionCodec is not provided")
        }
    }

    fn try_decode_scalar_udf(
        &self,
        buf: &[u8],
        _ctx: &SessionContext,
    ) -> Result<Arc<dyn ScalarUDFImpl>> {
        if let Ok(proto) = MyRegexUdfProto::decode(buf) {
            Ok(Arc::new(MyRegexUdf {
                compiled_regex: proto.compiled_regex,
            }))
        } else {
            not_impl_err!("LogicalExtensionCodec is not provided")
        }
    }
}

However, this doesn't play nicely with how serde is currently defined: our impl TryFrom<&Expr> for protobuf::LogicalExprNode has no way to get hold of a LogicalExtensionCodec, which it would need in order to call these methods.
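Setting aside where the codec comes from, the downcast-then-encode / decode-by-name pattern in the codec sketch above can be checked in isolation. This is a std-only sketch under assumptions: `UdfImpl` and `ExtensionCodec` are hypothetical, trimmed-down stand-ins for `ScalarUDFImpl` and the proposed codec methods, errors are plain `String`s, and the "serialization" just copies raw bytes instead of using prost.

```rust
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

/// Hypothetical stand-in for ScalarUDFImpl: only what a codec needs.
trait UdfImpl: Debug + Send + Sync {
    fn as_any(&self) -> &dyn Any;
    fn name(&self) -> &str;
}

#[derive(Debug, PartialEq)]
struct MyRegexUdf {
    compiled_regex: Vec<u8>,
}

impl UdfImpl for MyRegexUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "my_regex"
    }
}

/// Hypothetical stand-in for the proposed codec methods.
trait ExtensionCodec {
    fn try_encode_udf(&self, node: &dyn UdfImpl, buf: &mut Vec<u8>) -> Result<(), String>;
    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<dyn UdfImpl>, String>;
}

struct MyCodec;

impl ExtensionCodec for MyCodec {
    fn try_encode_udf(&self, node: &dyn UdfImpl, buf: &mut Vec<u8>) -> Result<(), String> {
        // Downcast to the concrete type this codec knows how to serialize.
        match node.as_any().downcast_ref::<MyRegexUdf>() {
            Some(udf) => {
                buf.extend_from_slice(&udf.compiled_regex);
                Ok(())
            }
            None => Err(format!("cannot encode UDF {}", node.name())),
        }
    }

    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Arc<dyn UdfImpl>, String> {
        // The name tells the codec which concrete type to rebuild.
        match name {
            "my_regex" => Ok(Arc::new(MyRegexUdf { compiled_regex: buf.to_vec() })),
            other => Err(format!("cannot decode UDF {other}")),
        }
    }
}
```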

@thinkharderdev
Contributor

Perhaps this could be incorporated into FunctionRegistry, with a serialize_expr(expr: &Expr, registry: &dyn FunctionRegistry) function similar to how we do the deserialization?

@yyy1000
Contributor

yyy1000 commented Jan 29, 2024

@thinkharderdev Much appreciated! I understand the issue now.
Another question: is there some example code for deserialization that I can refer to when working out how to write the serialize_expr function?

@thinkharderdev
Contributor

thinkharderdev commented Feb 26, 2024

@thinkharderdev Much appreciated! I understand the issue now. Another question: is there some example code for deserialization that I can refer to when working out how to write the serialize_expr function?

Edit: On second thought, we can probably just extend LogicalExtensionCodec and PhysicalExtensionCodec.

Hey @yyy1000, I've looked at this a bit and I think what we want here is:

  1. Have ScalarUDFExprNode take an optional opaque payload to potentially contain the serialized function:
message ScalarUDFExprNode {
  string fun_name = 1;
  repeated LogicalExprNode args = 2;
  optional bytes fun_definition = 3;
}

We would do the same for AggregateUDFExprNode and WindowExprNode.

  2. Extend LogicalExtensionCodec and PhysicalExtensionCodec to support custom serde for UDFs/UDAFs/UDWFs:
pub trait LogicalExtensionCodec: Debug + Send + Sync {
    // ... existing methods unchanged

    fn try_decode_udf(
        &self,
        name: &str,
        buf: &[u8],
    ) -> Result<Arc<ScalarUDF>>;

    fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec<u8>) -> Result<()>;

    fn try_decode_udaf(
        &self,
        name: &str,
        buf: &[u8],
    ) -> Result<Arc<AggregateUDF>>;

    fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()>;

    fn try_decode_udwf(
        &self,
        name: &str,
        buf: &[u8],
    ) -> Result<Arc<WindowUDF>>;

    fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()>;
}
  3. Replace impl TryFrom<&Expr> for protobuf::LogicalExprNode with a free function:
pub fn serialize_expr(expr: &Expr, codec: &dyn LogicalExtensionCodec) -> Result<protobuf::LogicalExprNode> {
   ...
}

This would be mostly unchanged from the existing TryFrom implementation except for handling of Expr::ScalarFunction/AggregateFunction/WindowFunction. We would handle Expr::ScalarFunction something like:

            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
                // Serialize the arguments with the same codec so nested UDFs are handled too
                let args = args
                    .iter()
                    .map(|expr| serialize_expr(expr, codec))
                    .collect::<Result<Vec<_>, Error>>()?;
                match func_def {
                    ScalarFunctionDefinition::BuiltIn(fun) => {
                        let fun: protobuf::ScalarFunction = fun.try_into()?;
                        protobuf::LogicalExprNode {
                            expr_type: Some(ExprType::ScalarFunction(
                                protobuf::ScalarFunctionNode {
                                    fun: fun.into(),
                                    args,
                                },
                            )),
                        }
                    }
                    ScalarFunctionDefinition::UDF(fun) => {
                        // Let the codec write any custom state; an empty buffer
                        // means "serialize by name only"
                        let mut buf = Vec::new();
                        codec.try_encode_udf(fun.as_ref(), &mut buf)?;

                        let fun_definition = if buf.is_empty() {
                            None
                        } else {
                            Some(buf)
                        };

                        protobuf::LogicalExprNode {
                            expr_type: Some(ExprType::ScalarUdfExpr(
                                protobuf::ScalarUdfExprNode {
                                    fun_name: fun.name().to_string(),
                                    fun_definition,
                                    args,
                                },
                            )),
                        }
                    }
                    ScalarFunctionDefinition::Name(_) => {
                        return Err(Error::NotImplemented(
                            "Proto serialization error: Trying to serialize an unresolved function"
                                .to_string(),
                        ));
                    }
                }
            }

  4. Similarly, in the existing parse_expr we would use the extension codec to deserialize the function if fun_definition is present:
pub fn parse_expr(
    proto: &protobuf::LogicalExprNode,
    registry: &dyn FunctionRegistry,
    codec: &dyn LogicalExtensionCodec,
) -> Result<Expr, Error> {
    // ... handling of other expr types unchanged

    ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, fun_definition, args }) => {
        // Use the codec if a serialized definition was sent, otherwise look the UDF up by name
        let scalar_fn = match fun_definition {
            Some(buf) => codec.try_decode_udf(&fun_name, &buf)?,
            None => registry.udf(fun_name.as_str())?,
        };
        Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
            scalar_fn,
            args.iter()
                // pass the codec down so nested UDFs in the arguments are decoded too
                .map(|expr| parse_expr(expr, registry, codec))
                .collect::<Result<Vec<_>, Error>>()?,
        )))
    }
}
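Putting steps 1 through 4 together, here is a std-only round-trip sketch. Everything in it (`Expr`, `Node`, `Codec`, `MyCodec`, `Registry`) is a hypothetical, simplified stand-in for DataFusion's `Expr`, the protobuf node, `LogicalExtensionCodec`, and `FunctionRegistry`, and UDF "state" is just raw bytes. It shows the two properties the steps rely on: the codec is threaded through the recursive calls on `args`, and an empty encode buffer leaves `fun_definition` unset so the receiver falls back to a name lookup.

```rust
use std::collections::HashMap;

/// Hypothetical expression tree: a literal or a UDF call carrying opaque state.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Literal(i64),
    Udf { name: String, state: Vec<u8>, args: Vec<Expr> },
}

/// Hypothetical wire form mirroring the proposed ScalarUdfExprNode fields.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    Literal(i64),
    Udf { fun_name: String, fun_definition: Option<Vec<u8>>, args: Vec<Node> },
}

/// Hypothetical stand-in for the proposed codec methods.
trait Codec {
    fn try_encode_udf(&self, name: &str, state: &[u8], buf: &mut Vec<u8>) -> Result<(), String>;
    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Vec<u8>, String>;
}

/// Knows how to (de)serialize the stateful "my_regex" UDF only.
struct MyCodec;

impl Codec for MyCodec {
    fn try_encode_udf(&self, name: &str, state: &[u8], buf: &mut Vec<u8>) -> Result<(), String> {
        if name == "my_regex" {
            buf.extend_from_slice(state);
        }
        Ok(()) // writing nothing means "serialize by name only"
    }

    fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result<Vec<u8>, String> {
        if name == "my_regex" {
            Ok(buf.to_vec())
        } else {
            Err(format!("cannot decode {name}"))
        }
    }
}

/// Stand-in for FunctionRegistry: known UDFs by name, with their default state.
type Registry = HashMap<String, Vec<u8>>;

fn serialize_expr(expr: &Expr, codec: &dyn Codec) -> Result<Node, String> {
    match expr {
        Expr::Literal(v) => Ok(Node::Literal(*v)),
        Expr::Udf { name, state, args } => {
            // Arguments go through the same codec, so nested UDFs are handled too.
            let args = args.iter().map(|a| serialize_expr(a, codec)).collect::<Result<Vec<_>, _>>()?;
            let mut buf = Vec::new();
            codec.try_encode_udf(name, state, &mut buf)?;
            let fun_definition = if buf.is_empty() { None } else { Some(buf) };
            Ok(Node::Udf { fun_name: name.clone(), fun_definition, args })
        }
    }
}

fn parse_expr(node: &Node, registry: &Registry, codec: &dyn Codec) -> Result<Expr, String> {
    match node {
        Node::Literal(v) => Ok(Expr::Literal(*v)),
        Node::Udf { fun_name, fun_definition, args } => {
            // Prefer the codec when a payload is present; otherwise look up by name.
            let state = match fun_definition {
                Some(buf) => codec.try_decode_udf(fun_name, buf)?,
                None => registry
                    .get(fun_name)
                    .cloned()
                    .ok_or_else(|| format!("{fun_name} not registered"))?,
            };
            let args = args.iter().map(|a| parse_expr(a, registry, codec)).collect::<Result<Vec<_>, _>>()?;
            Ok(Expr::Udf { name: fun_name.clone(), state, args })
        }
    }
}
```

In this sketch a stateless UDF such as `upper` must still be registered on the receiving side, while `my_regex` travels with its state.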

If you don't have the bandwidth to work on this now, let me know. My team can take this up, as we are hoping to use this functionality soon.

@yyy1000
Contributor

yyy1000 commented Feb 27, 2024

Thanks so much for your detailed instructions! @thinkharderdev
I can try to implement it now and will seek your help if needed :)

@yyy1000
Contributor

yyy1000 commented Feb 27, 2024

One question: replacing impl TryFrom<&Expr> for protobuf::LogicalExprNode with a free function means changing a lot of call sites from expr.try_into() to serialize_expr(expr, codec). Would this be OK for the code base? @thinkharderdev

@thinkharderdev
Contributor

One question: replacing impl TryFrom<&Expr> for protobuf::LogicalExprNode with a free function means changing a lot of call sites from expr.try_into() to serialize_expr(expr, codec). Would this be OK for the code base? @thinkharderdev

Yeah I think so. @alamb do you see any issues with that?

@yyy1000
Contributor

yyy1000 commented Feb 27, 2024

Also, I wonder how to deal with places that currently don't need a LogicalExtensionCodec, such as to_bytes:
https://github.com/apache/arrow-datafusion/blob/c439bc73b6a9ba9efa4c8a9b5d2fb6111e660e74/datafusion/proto/src/bytes/mod.rs#L87-L92. I think adding a parameter of type LogicalExtensionCodec to the function would work, but then what should callers pass when calling this method? Should I instantiate some impl LogicalExtensionCodec? 🤔
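One possible answer, sketched here under assumptions rather than as DataFusion's actual API: keep the existing `to_bytes()` signature and have it delegate to a codec-taking variant, passing a default codec that writes nothing, so callers who don't need custom UDF serde keep today's name-only behaviour. `Codec`, `DefaultCodec`, `Expr`, `to_bytes_with_codec`, and the toy byte layout are all hypothetical.

```rust
/// Hypothetical codec trait: may append custom UDF state to `buf`.
trait Codec {
    fn try_encode_udf(&self, name: &str, buf: &mut Vec<u8>) -> Result<(), String>;
}

/// Hypothetical default codec: writes nothing, keeping name-only serialization.
struct DefaultCodec;

impl Codec for DefaultCodec {
    fn try_encode_udf(&self, _name: &str, _buf: &mut Vec<u8>) -> Result<(), String> {
        Ok(())
    }
}

/// Hypothetical expression holding a single UDF call.
struct Expr {
    udf_name: String,
}

impl Expr {
    /// Existing-style entry point: signature unchanged, delegates with the default codec.
    fn to_bytes(&self) -> Result<Vec<u8>, String> {
        self.to_bytes_with_codec(&DefaultCodec)
    }

    /// Hypothetical entry point for callers that need custom UDF serde.
    fn to_bytes_with_codec(&self, codec: &dyn Codec) -> Result<Vec<u8>, String> {
        let mut payload = Vec::new();
        codec.try_encode_udf(&self.udf_name, &mut payload)?;
        // Toy wire format: the name bytes, a 0 separator, then the codec payload.
        let mut out = self.udf_name.as_bytes().to_vec();
        out.push(0);
        out.extend_from_slice(&payload);
        Ok(out)
    }
}
```

The design choice here is that only call sites which actually carry stateful UDFs need to know a codec exists.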

@alamb
Contributor Author

alamb commented Feb 28, 2024

One question: replacing impl TryFrom<&Expr> for protobuf::LogicalExprNode with a free function means changing a lot of call sites from expr.try_into() to serialize_expr(expr, codec). Would this be OK for the code base? @thinkharderdev

Yeah I think so. @alamb do you see any issues with that?

I don't see any specific issue, and I don't think it would affect users of the crate much -- I don't think they typically use the protobuf encoding directly, but rather go through higher-level APIs like Expr::to_bytes():

https://github.com/apache/arrow-datafusion/blob/e62240969135e2236d100c8c0c01546a87950a80/datafusion/proto/src/lib.rs#L52-L68
