diff --git a/pkg/planner/core/operator/logicalop/logical_apply.go b/pkg/planner/core/operator/logicalop/logical_apply.go index 7ba12a78189f0..6420c2ffe4cc4 100644 --- a/pkg/planner/core/operator/logicalop/logical_apply.go +++ b/pkg/planner/core/operator/logicalop/logical_apply.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" + base2 "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/planner/core/base" fd "github.com/pingcap/tidb/pkg/planner/funcdep" "github.com/pingcap/tidb/pkg/planner/property" @@ -45,6 +46,39 @@ func (la LogicalApply) Init(ctx base.PlanContext, offset int) *LogicalApply { return &la } +// *************************** start implementation of HashEquals interface **************************** + +// Hash64 implements the base.Hash64.<0th> interface. +func (la *LogicalApply) Hash64(h base2.Hasher) { + la.LogicalJoin.Hash64(h) + h.HashInt(len(la.CorCols)) + for _, one := range la.CorCols { + one.Hash64(h) + } + h.HashBool(la.NoDecorrelate) +} + +// Equals implements the base.HashEquals.<1st> interface. +func (la *LogicalApply) Equals(other any) bool { + if other == nil { + return false + } + la2, ok := other.(*LogicalApply) + if !ok { + return false + } + ok = la.LogicalJoin.Equals(&la2.LogicalJoin) && len(la.CorCols) == len(la2.CorCols) && la.NoDecorrelate == la2.NoDecorrelate + if !ok { + return false + } + for i, one := range la.CorCols { + if !one.Equals(la2.CorCols[i]) { + return false + } + } + return true +} + // *************************** start implementation of Plan interface *************************** // ExplainInfo implements Plan interface. diff --git a/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel b/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel index 029c61703ab86..7695efef4251c 100644 --- a/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel +++ b/pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "logical_mem_table_predicate_extractor_test.go", ], flaky = True, - shard_count = 15, + shard_count = 16, deps = [ "//pkg/domain", "//pkg/expression", diff --git a/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go b/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go index f1139cfb9d893..87cb7b9cdf5a3 100644 --- a/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go +++ b/pkg/planner/core/operator/logicalop/logicalop_test/hash64_equals_test.go @@ -28,6 +28,60 @@ import ( "github.com/stretchr/testify/require" ) +func TestLogicalApplyHash64Equals(t *testing.T) { + col1 := &expression.Column{ + ID: 1, + Index: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col2 := &expression.Column{ + ID: 2, + Index: 1, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col3 := &expression.Column{ + ID: 3, + Index: 2, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + ctx := mock.NewContext() + eq, err := expression.NewFunction(ctx, ast.EQ, types.NewFieldType(mysql.TypeLonglong), col1, col2) + require.Nil(t, err) + join := &logicalop.LogicalJoin{ + JoinType: logicalop.InnerJoin, + EqualConditions: []*expression.ScalarFunction{eq.(*expression.ScalarFunction)}, + } + la1 := logicalop.LogicalApply{ + LogicalJoin: *join, + CorCols: []*expression.CorrelatedColumn{{Column: *col3}}, + } + la2 := logicalop.LogicalApply{ + LogicalJoin: *join, + CorCols: []*expression.CorrelatedColumn{{Column: *col3}}, + } + hasher1 := base.NewHashEqualer() + hasher2 := base.NewHashEqualer() + la1.Hash64(hasher1) + la2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) + + la2.CorCols = []*expression.CorrelatedColumn{{Column: *col2}} + hasher2.Reset() + la2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + la2.CorCols = []*expression.CorrelatedColumn{{Column: *col3}} + la2.NoDecorrelate = true + hasher2.Reset() + la2.Hash64(hasher2) + require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64()) + + la2.NoDecorrelate = false + hasher2.Reset() + la2.Hash64(hasher2) + require.Equal(t, hasher1.Sum64(), hasher2.Sum64()) +} + func TestLogicalJoinHash64Equals(t *testing.T) { col1 := &expression.Column{ ID: 1,