From 80f57d1c24f6a7d8fadb11a3271a4d3f26e8d826 Mon Sep 17 00:00:00 2001 From: CyJaySong <29367599+cyjaysong@users.noreply.github.com> Date: Thu, 26 Dec 2024 18:18:35 +0800 Subject: [PATCH] fix(database/gdb): `gdb.Counter` not work in `OnDuplicate` (#4073) --- .../drivers/mysql/mysql_z_unit_model_test.go | 22 +++++ contrib/drivers/pgsql/pgsql_format_upsert.go | 19 ++++ .../drivers/pgsql/pgsql_z_unit_model_test.go | 22 +++++ .../drivers/sqlite/sqlite_format_upsert.go | 19 ++++ .../sqlite/sqlite_z_unit_model_test.go | 22 +++++ .../drivers/sqlitecgo/sqlite_format_upsert.go | 90 +++++++++++++++++++ .../drivers/sqlitecgo/sqlitecgo_do_filter.go | 10 --- contrib/drivers/sqlitecgo/sqlitecgo_tables.go | 10 +-- .../sqlitecgo/sqlitecgo_z_unit_init_test.go | 5 -- .../sqlitecgo/sqlitecgo_z_unit_model_test.go | 24 ++++- database/gdb/gdb_core.go | 51 +++++------ database/gdb/gdb_core_underlying.go | 16 ++++ 12 files changed, 264 insertions(+), 46 deletions(-) create mode 100644 contrib/drivers/sqlitecgo/sqlite_format_upsert.go diff --git a/contrib/drivers/mysql/mysql_z_unit_model_test.go b/contrib/drivers/mysql/mysql_z_unit_model_test.go index 9be10ff1313..d28f461c241 100644 --- a/contrib/drivers/mysql/mysql_z_unit_model_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_model_test.go @@ -2812,6 +2812,28 @@ func Test_Model_OnDuplicate(t *testing.T) { }) } +func Test_Model_OnDuplicateWithCounter(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "id": gdb.Counter{Field: "id", Value: 999999}, + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.AssertNil(one) + }) +} + func Test_Model_OnDuplicateEx(t *testing.T) { table := createInitTable() defer dropTable(table) diff --git a/contrib/drivers/pgsql/pgsql_format_upsert.go b/contrib/drivers/pgsql/pgsql_format_upsert.go index c4c8af91122..fc003cb4ce5 100644 --- a/contrib/drivers/pgsql/pgsql_format_upsert.go +++ b/contrib/drivers/pgsql/pgsql_format_upsert.go @@ -40,6 +40,25 @@ func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInse d.Core.QuoteWord(k), v, ) + case gdb.Counter, *gdb.Counter: + var counter gdb.Counter + switch value := v.(type) { + case gdb.Counter: + counter = value + case *gdb.Counter: + counter = *value + } + operator, columnVal := "+", counter.Value + if columnVal < 0 { + operator, columnVal = "-", -columnVal + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s%s%s", + d.QuoteWord(k), + d.QuoteWord(counter.Field), + operator, + gconv.String(columnVal), + ) default: onDuplicateStr += fmt.Sprintf( "%s=EXCLUDED.%s", diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index d7748f07f17..3a51ba264dc 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -521,6 +521,28 @@ func Test_Model_OnDuplicate(t *testing.T) { }) } +func Test_Model_OnDuplicateWithCounter(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "id": gdb.Counter{Field: "id", Value: 999999}, + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.AssertNil(one) + }) +} + func Test_Model_OnDuplicateEx(t *testing.T) { table := createInitTable() defer dropTable(table) diff --git a/contrib/drivers/sqlite/sqlite_format_upsert.go b/contrib/drivers/sqlite/sqlite_format_upsert.go index 5821144a13e..c80d5dfd74e 100644 --- a/contrib/drivers/sqlite/sqlite_format_upsert.go +++ b/contrib/drivers/sqlite/sqlite_format_upsert.go @@ -40,6 +40,25 @@ func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInse d.Core.QuoteWord(k), v, ) + case gdb.Counter, *gdb.Counter: + var counter gdb.Counter + switch value := v.(type) { + case gdb.Counter: + counter = value + case *gdb.Counter: + counter = *value + } + operator, columnVal := "+", counter.Value + if columnVal < 0 { + operator, columnVal = "-", -columnVal + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s%s%s", + d.QuoteWord(k), + d.QuoteWord(counter.Field), + operator, + gconv.String(columnVal), + ) default: onDuplicateStr += fmt.Sprintf( "%s=EXCLUDED.%s", diff --git a/contrib/drivers/sqlite/sqlite_z_unit_model_test.go b/contrib/drivers/sqlite/sqlite_z_unit_model_test.go index 19e97bfa51a..03e8465c7fc 100644 --- a/contrib/drivers/sqlite/sqlite_z_unit_model_test.go +++ b/contrib/drivers/sqlite/sqlite_z_unit_model_test.go @@ -4324,3 +4324,25 @@ func Test_OrderRandom(t *testing.T) { t.Assert(len(result), TableSize) }) } + +func Test_Model_OnDuplicateWithCounter(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "id": gdb.Counter{Field: "id", Value: 999999}, + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.AssertNil(one) + }) +} diff --git a/contrib/drivers/sqlitecgo/sqlite_format_upsert.go b/contrib/drivers/sqlitecgo/sqlite_format_upsert.go new file mode 100644 index 00000000000..63cb33e5a06 --- /dev/null +++ b/contrib/drivers/sqlitecgo/sqlite_format_upsert.go @@ -0,0 +1,90 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package sqlitecgo + +import ( + "fmt" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// FormatUpsert returns SQL clause of type upsert for SQLite. +// For example: ON CONFLICT (id) DO UPDATE SET ... +func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) { + if len(option.OnConflict) == 0 { + return "", gerror.NewCode( + gcode.CodeMissingParameter, `Please specify conflict columns`, + ) + } + + var onDuplicateStr string + if option.OnDuplicateStr != "" { + onDuplicateStr = option.OnDuplicateStr + } else if len(option.OnDuplicateMap) > 0 { + for k, v := range option.OnDuplicateMap { + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + switch v.(type) { + case gdb.Raw, *gdb.Raw: + onDuplicateStr += fmt.Sprintf( + "%s=%s", + d.Core.QuoteWord(k), + v, + ) + case gdb.Counter, *gdb.Counter: + var counter gdb.Counter + switch value := v.(type) { + case gdb.Counter: + counter = value + case *gdb.Counter: + counter = *value + } + operator, columnVal := "+", counter.Value + if columnVal < 0 { + operator, columnVal = "-", -columnVal + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s%s%s", + d.QuoteWord(k), + d.QuoteWord(counter.Field), + operator, + gconv.String(columnVal), + ) + default: + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(k), + d.Core.QuoteWord(gconv.String(v)), + ) + } + } + } else { + for _, column := range columns { + // If it's SAVE operation, do not automatically update the creating time. + if d.Core.IsSoftCreatedFieldName(column) { + continue + } + if len(onDuplicateStr) > 0 { + onDuplicateStr += "," + } + onDuplicateStr += fmt.Sprintf( + "%s=EXCLUDED.%s", + d.Core.QuoteWord(column), + d.Core.QuoteWord(column), + ) + } + } + + conflictKeys := gstr.Join(option.OnConflict, ",") + + return fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET ", conflictKeys) + onDuplicateStr, nil +} diff --git a/contrib/drivers/sqlitecgo/sqlitecgo_do_filter.go b/contrib/drivers/sqlitecgo/sqlitecgo_do_filter.go index 6c578225e53..476ae1a0e97 100644 --- a/contrib/drivers/sqlitecgo/sqlitecgo_do_filter.go +++ b/contrib/drivers/sqlitecgo/sqlitecgo_do_filter.go @@ -10,8 +10,6 @@ import ( "context" "github.com/gogf/gf/v2/database/gdb" - "github.com/gogf/gf/v2/errors/gcode" - "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/text/gstr" ) @@ -26,14 +24,6 @@ func (d *Driver) DoFilter( case gstr.HasPrefix(sql, gdb.InsertOperationReplace): sql = "INSERT OR REPLACE" + sql[len(gdb.InsertOperationReplace):] - - default: - if gstr.Contains(sql, gdb.InsertOnDuplicateKeyUpdate) { - return sql, args, gerror.NewCode( - gcode.CodeNotSupported, - `Save operation is not supported by sqlite driver`, - ) - } } return d.Core.DoFilter(ctx, link, sql, args) } diff --git a/contrib/drivers/sqlitecgo/sqlitecgo_tables.go b/contrib/drivers/sqlitecgo/sqlitecgo_tables.go index d985c3a62af..993c6cf8bdc 100644 --- a/contrib/drivers/sqlitecgo/sqlitecgo_tables.go +++ b/contrib/drivers/sqlitecgo/sqlitecgo_tables.go @@ -12,6 +12,10 @@ import ( "github.com/gogf/gf/v2/database/gdb" ) +const ( + tablesSqlTmp = `SELECT NAME FROM SQLITE_MASTER WHERE TYPE='table' ORDER BY NAME` +) + // Tables retrieves and returns the tables of current schema. // It's mainly used in cli tool chain for automatically generating the models. func (d *Driver) Tables(ctx context.Context, schema ...string) (tables []string, err error) { @@ -21,11 +25,7 @@ func (d *Driver) Tables(ctx context.Context, schema ...string) (tables []string, return nil, err } - result, err = d.DoSelect( - ctx, - link, - `SELECT NAME FROM SQLITE_MASTER WHERE TYPE='table' ORDER BY NAME`, - ) + result, err = d.DoSelect(ctx, link, tablesSqlTmp) if err != nil { return } diff --git a/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_init_test.go b/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_init_test.go index fe778b93f8f..0dd2aa0dc76 100644 --- a/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_init_test.go +++ b/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_init_test.go @@ -11,8 +11,6 @@ import ( "github.com/gogf/gf/v2/container/garray" "github.com/gogf/gf/v2/database/gdb" - "github.com/gogf/gf/v2/errors/gcode" - "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/os/gfile" @@ -27,9 +25,6 @@ var ( configNode gdb.ConfigNode dbDir = gfile.Temp("sqlite") ctx = gctx.New() - - // Error - ErrorSave = gerror.NewCode(gcode.CodeNotSupported, `Save operation is not supported by sqlite driver`) ) const ( diff --git a/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_model_test.go b/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_model_test.go index 613ef9c64b3..3ad793df2d3 100644 --- a/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_model_test.go +++ b/contrib/drivers/sqlitecgo/sqlitecgo_z_unit_model_test.go @@ -375,7 +375,7 @@ func Test_Model_Save(t *testing.T) { "nickname": "oldme", "create_time": CreateTime, }).OnConflict("id").Save() - t.Assert(err, ErrorSave) + t.AssertNil(err) }) } @@ -4361,3 +4361,25 @@ func TestResult_Structs1(t *testing.T) { t.Assert(array[1].Name, "smith") }) } + +func Test_Model_OnDuplicateWithCounter(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": "pp1", + "password": "pw1", + "nickname": "n1", + "create_time": "2016-06-06", + } + _, err := db.Model(table).OnConflict("id").OnDuplicate(g.Map{ + "id": gdb.Counter{Field: "id", Value: 999999}, + }).Data(data).Save() + t.AssertNil(err) + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.AssertNil(one) + }) +} diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 50971611258..f4faa0bf4f9 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -583,24 +583,8 @@ func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data inter switch kind { case reflect.Map, reflect.Struct: var ( - fields []string - dataMap map[string]interface{} - counterHandler = func(column string, counter Counter) { - if counter.Value != 0 { - column = c.QuoteWord(column) - var ( - columnRef = c.QuoteWord(counter.Field) - columnVal = counter.Value - operator = "+" - ) - if columnVal < 0 { - operator = "-" - columnVal = -columnVal - } - fields = append(fields, fmt.Sprintf("%s=%s%s?", column, columnRef, operator)) - params = append(params, columnVal) - } - } + fields []string + dataMap map[string]interface{} ) dataMap, err = c.ConvertDataForRecord(ctx, data, table) if err != nil { @@ -620,13 +604,21 @@ func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data inter } for _, k := range keysInSequence { v := dataMap[k] - switch value := v.(type) { - case *Counter: - counterHandler(k, *value) - - case Counter: - counterHandler(k, value) - + switch v.(type) { + case Counter, *Counter: + var counter Counter + switch value := v.(type) { + case Counter: + counter = value + case *Counter: + counter = *value + } + if counter.Value == 0 { + continue + } + operator, columnVal := c.getCounterAlter(counter) + fields = append(fields, fmt.Sprintf("%s=%s%s?", c.QuoteWord(k), c.QuoteWord(counter.Field), operator)) + params = append(params, columnVal) default: if s, ok := v.(Raw); ok { fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s)) @@ -796,3 +788,12 @@ func (c *Core) IsSoftCreatedFieldName(fieldName string) bool { func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) { return handleSliceAndStructArgsForSql(sql, args) } + +// getCounterAlter +func (c *Core) getCounterAlter(counter Counter) (operator string, columnVal float64) { + operator, columnVal = "+", counter.Value + if columnVal < 0 { + operator, columnVal = "-", -columnVal + } + return +} diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 045d11c65af..25c60a4baf7 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -388,6 +388,22 @@ func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) c.QuoteWord(k), v, ) + case Counter, *Counter: + var counter Counter + switch value := v.(type) { + case Counter: + counter = value + case *Counter: + counter = *value + } + operator, columnVal := c.getCounterAlter(counter) + onDuplicateStr += fmt.Sprintf( + "%s=%s%s%s", + c.QuoteWord(k), + c.QuoteWord(counter.Field), + operator, + gconv.String(columnVal), + ) default: onDuplicateStr += fmt.Sprintf( "%s=VALUES(%s)",