diff --git a/expression/builtin_arithmetic.go b/expression/builtin_arithmetic.go index f2524dd0792fa..77ae11d93d7de 100644 --- a/expression/builtin_arithmetic.go +++ b/expression/builtin_arithmetic.go @@ -670,6 +670,7 @@ func (c *arithmeticIntDivideFunctionClass) getFunction(ctx sessionctx.Context, a bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticIntDivideIntSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_IntDivideInt) return sig, nil } bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETDecimal, types.ETDecimal) @@ -677,6 +678,7 @@ func (c *arithmeticIntDivideFunctionClass) getFunction(ctx sessionctx.Context, a bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticIntDivideDecimalSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_IntDivideDecimal) return sig, nil } @@ -834,6 +836,7 @@ func (c *arithmeticModFunctionClass) getFunction(ctx sessionctx.Context, args [] bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticModRealSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ModReal) return sig, nil } else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal { bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal) @@ -842,6 +845,7 @@ func (c *arithmeticModFunctionClass) getFunction(ctx sessionctx.Context, args [] bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticModDecimalSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ModDecimal) return sig, nil } else { bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) @@ -849,6 +853,7 @@ func (c *arithmeticModFunctionClass) getFunction(ctx sessionctx.Context, args [] bf.tp.Flag |= mysql.UnsignedFlag } sig := &builtinArithmeticModIntSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ModInt) return sig, nil } } diff --git a/expression/builtin_arithmetic_test.go b/expression/builtin_arithmetic_test.go index a1e5afc63180d..dfc8af3d8f364 100644 --- a/expression/builtin_arithmetic_test.go +++ b/expression/builtin_arithmetic_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tidb/util/testutil" + "github.com/pingcap/tipb/go-tipb" ) func (s *testEvaluatorSuite) TestSetFlenDecimal4RealOrDecimal(c *C) { @@ -376,6 +377,12 @@ func (s *testEvaluatorSuite) TestArithmeticDivide(c *C) { sig, err := funcs[ast.Div].getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tc.args...))) c.Assert(err, IsNil) c.Assert(sig, NotNil) + switch sig.(type) { + case *builtinArithmeticIntDivideIntSig: + c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_IntDivideInt) + case *builtinArithmeticIntDivideDecimalSig: + c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_IntDivideDecimal) + } val, err := evalBuiltinFunc(sig, chunk.Row{}) c.Assert(err, IsNil) c.Assert(val, testutil.DatumEquals, types.NewDatum(tc.expect)) @@ -601,6 +608,14 @@ func (s *testEvaluatorSuite) TestArithmeticMod(c *C) { c.Assert(err, IsNil) c.Assert(sig, NotNil) val, err := evalBuiltinFunc(sig, chunk.Row{}) + switch sig.(type) { + case *builtinArithmeticModRealSig: + c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModReal) + case *builtinArithmeticModIntSig: + c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModInt) + case *builtinArithmeticModDecimalSig: + c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModDecimal) + } c.Assert(err, IsNil) c.Assert(val, testutil.DatumEquals, types.NewDatum(tc.expect)) }