Skip to content

Commit

Permalink
expression: add missing setPbCode() for some arithmetic function (pin…
Browse files Browse the repository at this point in the history
  • Loading branch information
lonng authored and XiaTianliang committed Dec 21, 2019
1 parent 9b3f4f6 commit 238c0a2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
5 changes: 5 additions & 0 deletions expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -670,13 +670,15 @@ func (c *arithmeticIntDivideFunctionClass) getFunction(ctx sessionctx.Context, a
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticIntDivideIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_IntDivideInt)
return sig, nil
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETDecimal, types.ETDecimal)
if mysql.HasUnsignedFlag(lhsTp.Flag) || mysql.HasUnsignedFlag(rhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticIntDivideDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_IntDivideDecimal)
return sig, nil
}

Expand Down Expand Up @@ -834,6 +836,7 @@ func (c *arithmeticModFunctionClass) getFunction(ctx sessionctx.Context, args []
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_ModReal)
return sig, nil
} else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal)
Expand All @@ -842,13 +845,15 @@ func (c *arithmeticModFunctionClass) getFunction(ctx sessionctx.Context, args []
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_ModDecimal)
return sig, nil
} else {
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
if mysql.HasUnsignedFlag(lhsTp.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_ModInt)
return sig, nil
}
}
Expand Down
15 changes: 15 additions & 0 deletions expression/builtin_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
"github.com/pingcap/tipb/go-tipb"
)

func (s *testEvaluatorSuite) TestSetFlenDecimal4RealOrDecimal(c *C) {
Expand Down Expand Up @@ -376,6 +377,12 @@ func (s *testEvaluatorSuite) TestArithmeticDivide(c *C) {
sig, err := funcs[ast.Div].getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(tc.args...)))
c.Assert(err, IsNil)
c.Assert(sig, NotNil)
switch sig.(type) {
case *builtinArithmeticIntDivideIntSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_IntDivideInt)
case *builtinArithmeticIntDivideDecimalSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_IntDivideDecimal)
}
val, err := evalBuiltinFunc(sig, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(val, testutil.DatumEquals, types.NewDatum(tc.expect))
Expand Down Expand Up @@ -601,6 +608,14 @@ func (s *testEvaluatorSuite) TestArithmeticMod(c *C) {
c.Assert(err, IsNil)
c.Assert(sig, NotNil)
val, err := evalBuiltinFunc(sig, chunk.Row{})
switch sig.(type) {
case *builtinArithmeticModRealSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModReal)
case *builtinArithmeticModIntSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModInt)
case *builtinArithmeticModDecimalSig:
c.Assert(sig.PbCode(), Equals, tipb.ScalarFuncSig_ModDecimal)
}
c.Assert(err, IsNil)
c.Assert(val, testutil.DatumEquals, types.NewDatum(tc.expect))
}
Expand Down

0 comments on commit 238c0a2

Please # to comment.