Skip to content

Commit

Permalink
fix(BUX-686): applyJSONCondition
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-4chain committed Mar 28, 2024
1 parent 5e651e2 commit db804ee
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 46 deletions.
97 changes: 66 additions & 31 deletions engine/datastore/where.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,6 @@ func (builder *whereBuilder) applyCondition(tx customWhereInterface, key string,
tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)})
}

func (builder *whereBuilder) applyJson(tx customWhereInterface, key string, condition interface{}) {
columnName := builder.getColumnNameOrPanic(key)

varName := builder.nextVarName()
engine := builder.client.Engine()

if engine != PostgreSQL {
//todo handle other databases then postgres
panic("eoeoeoeoeoeoeoeoeoeoeo not implemented yet")
}

query := fmt.Sprintf("%s::jsonb @> @%s", columnName, varName)
tx.Where(query, map[string]interface{}{varName: builder.formatCondition(condition)})
}

func (builder *whereBuilder) applyExistsCondition(tx customWhereInterface, key string, condition bool) {
columnName := builder.getColumnNameOrPanic(key)

Expand Down Expand Up @@ -135,9 +120,9 @@ func (builder *whereBuilder) processConditions(tx customWhereInterface, conditio
} else if key == conditionExists {
builder.applyExistsCondition(tx, *parentKey, condition.(bool))
} else if StringInSlice(key, builder.client.GetArrayFields()) {
builder.applyArray(tx, key, condition)
builder.applyArray(tx, key, condition.(string))
} else if StringInSlice(key, builder.client.GetObjectFields()) {
builder.applyJson(tx, key, condition)
builder.applyJSONCondition(tx, key, condition)
} else {
if condition == nil {
builder.applyCondition(tx, key, "IS NULL", nil)
Expand Down Expand Up @@ -285,28 +270,78 @@ func (builder *whereBuilder) whereObject(k string, v interface{}) string {
}

// whereSlice generates the where slice
func (builder *whereBuilder) whereSlice(k string, v interface{}) string {
// func (builder *whereBuilder) whereSlice(k string, v interface{}) string {
// engine := builder.client.Engine()
// if engine == MySQL {
// return "JSON_CONTAINS(" + k + ", CAST('[\"" + v.(string) + "\"]' AS JSON))"
// } else if engine == PostgreSQL {
// return k + "::jsonb @> '[\"" + v.(string) + "\"]'"
// }
// return "EXISTS (SELECT 1 FROM json_each(" + k + ") WHERE value = \"" + v.(string) + "\")"
// }

func (builder *whereBuilder) applyArray(tx customWhereInterface, key string, condition string) {
columnName := builder.getColumnNameOrPanic(key)

varName := builder.nextVarName()
engine := builder.client.Engine()
if engine == MySQL {
return "JSON_CONTAINS(" + k + ", CAST('[\"" + v.(string) + "\"]' AS JSON))"
} else if engine == PostgreSQL {
return k + "::jsonb @> '[\"" + v.(string) + "\"]'"

query := ""
arg := ""

switch engine {
case PostgreSQL:
query = fmt.Sprintf("%s::jsonb @> @%s", columnName, varName)
arg = fmt.Sprintf(`["%s"]`, condition)
case MySQL:
query = fmt.Sprintf("JSON_CONTAINS(%s, CAST(@%s AS JSON))", columnName, varName)
arg = fmt.Sprintf(`["%s"]`, condition)
case SQLite:
query = fmt.Sprintf("EXISTS (SELECT 1 FROM json_each(%s) WHERE value = @%s)", columnName, varName)
arg = condition
default:
panic("Database engine not supported")
}
return "EXISTS (SELECT 1 FROM json_each(" + k + ") WHERE value = \"" + v.(string) + "\")"

tx.Where(query, map[string]interface{}{varName: arg})
}

func (builder *whereBuilder) applyArray(tx customWhereInterface, key string, condition interface{}) {
func (builder *whereBuilder) applyJSONCondition(tx customWhereInterface, key string, condition interface{}) {
columnName := builder.getColumnNameOrPanic(key)

varName := builder.nextVarName()
engine := builder.client.Engine()

if engine != PostgreSQL {
//todo handle other databases then postgres
panic("eoeoeoeoeoeoeoeoeoeoeo not implemented yet")
if engine == PostgreSQL {
builder.applyJSONBCondition(tx, columnName, condition)
} else if engine == MySQL || engine == SQLite {
builder.applyJSONExtractCondition(tx, columnName, condition)
} else {
panic("Database engine not supported")
}
}

func (builder *whereBuilder) applyJSONBCondition(tx customWhereInterface, columnName string, condition interface{}) {
varName := builder.nextVarName()
query := fmt.Sprintf("%s::jsonb @> @%s", columnName, varName)
c := condition.(string)
tx.Where(query, map[string]interface{}{varName: builder.formatCondition("[\"" + c + "\"]")})
tx.Where(query, map[string]interface{}{varName: condition})
}

func (builder *whereBuilder) applyJSONExtractCondition(tx customWhereInterface, columnName string, condition interface{}) {
dict := convertTo[map[string]interface{}](condition)
for key, value := range dict {
keyVarName := builder.nextVarName()
valueVarName := builder.nextVarName()
query := fmt.Sprintf("JSON_EXTRACT(%s, @%s) = @%s", columnName, keyVarName, valueVarName)
tx.Where(query, map[string]interface{}{
keyVarName: fmt.Sprintf("$.%s", key),
valueVarName: value,
})
}
}

func convertTo[T any](object interface{}) T {
vJSON, _ := json.Marshal(object)

var converted T
_ = json.Unmarshal(vJSON, &converted)
return converted
}
141 changes: 126 additions & 15 deletions engine/datastore/where_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
package datastore

import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"testing"
"time"

"github.com/DATA-DOG/go-sqlmock"
customtypes "github.com/bitcoin-sv/spv-wallet/engine/datastore/customtypes"
"github.com/bitcoin-sv/spv-wallet/engine/utils"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)

func mockDialector(engine Engine) gorm.Dialector {
Expand Down Expand Up @@ -56,40 +63,144 @@ func makeWhereBuilder(client *Client, gdb *gorm.DB, model interface{}) *whereBui
}
}

const (
// MetadataField is the field name used for metadata (params)
MetadataField = "metadata"
)

type Metadata map[string]interface{}

func (m Metadata) GormDataType() string {
return "text"
}

func (m *Metadata) Scan(value interface{}) error {
if value == nil {
return nil
}

byteValue, err := utils.ToByteArray(value)
if err != nil || bytes.Equal(byteValue, []byte("")) || bytes.Equal(byteValue, []byte("\"\"")) {
return nil
}

return json.Unmarshal(byteValue, &m)
}

func (m Metadata) Value() (driver.Value, error) {
if m == nil {
return nil, nil
}
marshal, err := json.Marshal(m)
if err != nil {
return nil, err
}

return string(marshal), nil
}

func (Metadata) GormDBDataType(db *gorm.DB, _ *schema.Field) string {
if db.Dialector.Name() == Postgres {
return JSONB
}
return JSON
}

func (m *Metadata) MarshalBSONValue() (bsontype.Type, []byte, error) {
if m == nil || len(*m) == 0 {
return bson.TypeNull, nil, nil
}

metadata := make([]map[string]interface{}, 0)
for key, value := range *m {
metadata = append(metadata, map[string]interface{}{
"k": key,
"v": value,
})
}

return bson.MarshalValue(metadata)
}

func (m *Metadata) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
raw := bson.RawValue{Type: t, Value: data}

if raw.Value == nil {
return nil
}

var uMap []map[string]interface{}
if err := raw.Unmarshal(&uMap); err != nil {
return err
}

*m = make(Metadata)
for _, meta := range uMap {
key := meta["k"].(string)
(*m)[key] = meta["v"]
}

return nil
}

type mockObject struct {
ID string
CreatedAt time.Time
UniqueFieldName string
Number int
ReferenceID string
Metadata Metadata
}

// Test_whereObject test the SQL where selector
func Test_whereSlice(t *testing.T) {
t.Parallel()

t.Run("MySQL", func(t *testing.T) {
client, gdb := mockClient(MySQL)
builder := makeWhereBuilder(client, gdb, mockObject{})
query := builder.whereSlice(fieldInIDs, "id_1")
expected := `JSON_CONTAINS(` + fieldInIDs + `, CAST('["id_1"]' AS JSON))`
assert.Equal(t, expected, query)
})
conditions := map[string]interface{}{
"metadata": Metadata{
"domain": "test-domain",
},
}

t.Run("Postgres", func(t *testing.T) {
client, gdb := mockClient(PostgreSQL)
builder := makeWhereBuilder(client, gdb, mockObject{})
query := builder.whereSlice(fieldInIDs, "id_1")
expected := fieldInIDs + `::jsonb @> '["id_1"]'`
assert.Equal(t, expected, query)

raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB {
tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{})
assert.NoError(t, err)
return tx.First(&mockObject{})
})

assert.Contains(t, raw, "metadata::jsonb @>")
assert.Contains(t, raw, `'{"domain":"test-domain"}'`)
})

t.Run("SQLite", func(t *testing.T) {
client, gdb := mockClient(SQLite)
builder := makeWhereBuilder(client, gdb, mockObject{})
query := builder.whereSlice(fieldInIDs, "id_1")
expected := `EXISTS (SELECT 1 FROM json_each(` + fieldInIDs + `) WHERE value = "id_1")`
assert.Equal(t, expected, query)

raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB {
tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{})
assert.NoError(t, err)
return tx.First(&mockObject{})
})

assert.Contains(t, raw, "JSON_EXTRACT(metadata")
assert.Contains(t, raw, `"$.domain"`)
assert.Contains(t, raw, `"test-domain"`)
})

t.Run("MySQL", func(t *testing.T) {
client, gdb := mockClient(MySQL)

raw := gdb.ToSQL(func(tx *gorm.DB) *gorm.DB {
tx, err := ApplyCustomWhere(client, tx, conditions, mockObject{})
assert.NoError(t, err)
return tx.First(&mockObject{})
})

assert.Contains(t, raw, "JSON_EXTRACT(metadata")
assert.Contains(t, raw, "'$.domain'")
assert.Contains(t, raw, "'test-domain'")
})
}

Expand Down

0 comments on commit db804ee

Please # to comment.