diff --git a/internal/core/src/common/FieldMeta.cpp b/internal/core/src/common/FieldMeta.cpp
index 9b6056bc46e0a..ca55d45e67ac2 100644
--- a/internal/core/src/common/FieldMeta.cpp
+++ b/internal/core/src/common/FieldMeta.cpp
@@ -20,7 +20,7 @@ namespace milvus {
 
 TokenizerParams
 ParseTokenizerParams(const TypeParams& params) {
-    auto iter = params.find("analyzer_params");
+    auto iter = params.find("tokenizer_params");
     if (iter == params.end()) {
         return {};
     }
@@ -47,9 +47,20 @@ FieldMeta::enable_match() const {
     return string_info_->enable_match;
 }
 
+bool
+FieldMeta::enable_tokenizer() const {
+    if (!IsStringDataType(type_)) {
+        return false;
+    }
+    if (!string_info_.has_value()) {
+        return false;
+    }
+    return string_info_->enable_tokenizer;
+}
+
 TokenizerParams
 FieldMeta::get_tokenizer_params() const {
-    Assert(enable_match());
+    Assert(enable_tokenizer());
     auto params = string_info_->params;
     return ParseTokenizerParams(params);
 }
@@ -91,29 +102,32 @@ FieldMeta::ParseFrom(const milvus::proto::schema::FieldSchema& schema_proto) {
         auto type_map = RepeatedKeyValToMap(schema_proto.type_params());
         AssertInfo(type_map.count(MAX_LENGTH), "max_length not found");
         auto max_len = boost::lexical_cast<int64_t>(type_map.at(MAX_LENGTH));
-        bool enable_match = false;
-        if (type_map.count("enable_match")) {
-            auto param_str = type_map.at("enable_match");
+
+        auto get_bool_value = [&](const std::string& key) -> bool {
+            if (!type_map.count(key)) {
+                return false;
+            }
+            auto param_str = type_map.at(key);
             std::transform(param_str.begin(),
                            param_str.end(),
                            param_str.begin(),
                            ::tolower);
+            std::istringstream ss(param_str);
+            bool b;
+            ss >> std::boolalpha >> b;
+            return b;
+        };
 
-            auto bool_cast = [](const std::string& arg) -> bool {
-                std::istringstream ss(arg);
-                bool b;
-                ss >> std::boolalpha >> b;
-                return b;
-            };
+        bool enable_tokenizer = get_bool_value("enable_tokenizer");
+        bool enable_match = get_bool_value("enable_match");
 
-            enable_match = bool_cast(param_str);
-        }
         return FieldMeta{name,
                          field_id,
                          data_type,
                          max_len,
                          nullable,
                          enable_match,
+                         enable_tokenizer,
                          type_map};
     }
diff --git a/internal/core/src/common/FieldMeta.h b/internal/core/src/common/FieldMeta.h
index ac965bbe7bce2..ed040902a54d6 100644
--- a/internal/core/src/common/FieldMeta.h
+++ b/internal/core/src/common/FieldMeta.h
@@ -64,11 +64,13 @@ class FieldMeta {
               int64_t max_length,
               bool nullable,
               bool enable_match,
+              bool enable_tokenizer,
               std::map<std::string, std::string>& params)
         : name_(name),
          id_(id),
           type_(type),
-          string_info_(StringInfo{max_length, enable_match, std::move(params)}),
+          string_info_(StringInfo{
+              max_length, enable_match, enable_tokenizer, std::move(params)}),
           nullable_(nullable) {
         Assert(IsStringDataType(type_));
     }
@@ -122,6 +124,9 @@ class FieldMeta {
     bool
     enable_match() const;
 
+    bool
+    enable_tokenizer() const;
+
     TokenizerParams
     get_tokenizer_params() const;
 
@@ -198,6 +203,7 @@ class FieldMeta {
     struct StringInfo {
         int64_t max_length;
         bool enable_match;
+        bool enable_tokenizer;
         std::map<std::string, std::string> params;
     };
     FieldName name_;
diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h
index fb585923a0fd1..6a610fc4691d7 100644
--- a/internal/core/src/common/Schema.h
+++ b/internal/core/src/common/Schema.h
@@ -121,9 +121,16 @@ class Schema {
              int64_t max_length,
              bool nullable,
              bool enable_match,
+             bool enable_tokenizer,
              std::map<std::string, std::string>& params) {
-        auto field_meta = FieldMeta(
-            name, id, data_type, max_length, nullable, enable_match, params);
+        auto field_meta = FieldMeta(name,
+                                    id,
+                                    data_type,
+                                    max_length,
+                                    nullable,
+                                    enable_match,
+                                    enable_tokenizer,
+                                    params);
         this->AddField(std::move(field_meta));
     }
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
index 9a1d34b2476f8..f031789c84109 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/tokenizer.rs
@@ -33,7 +33,7 @@ pub(crate) fn create_tokenizer(params: &HashMap) -> Option {
-            info!("no tokenizer is specific, use default tokenizer");
+            info!("no tokenizer is specified, using default tokenizer");
             Some(default_tokenizer())
         }
     }
 }
diff --git a/internal/core/unittest/test_c_tokenizer.cpp b/internal/core/unittest/test_c_tokenizer.cpp
index 2d47809f693d4..7e5c9e2a40df6 100644
--- a/internal/core/unittest/test_c_tokenizer.cpp
+++ b/internal/core/unittest/test_c_tokenizer.cpp
@@ -31,7 +31,7 @@ TEST(ValidateTextSchema, JieBa) {
     milvus::proto::schema::FieldSchema schema;
     {
         auto kv = schema.add_type_params();
-        kv->set_key("analyzer_params");
+        kv->set_key("tokenizer_params");
         kv->set_value(R"({"tokenizer": "jieba"})");
     }
 
diff --git a/internal/core/unittest/test_text_match.cpp b/internal/core/unittest/test_text_match.cpp
index c4de7beda03b4..55b85cad1d118 100644
--- a/internal/core/unittest/test_text_match.cpp
+++ b/internal/core/unittest/test_text_match.cpp
@@ -40,6 +40,7 @@ GenTestSchema(std::map<std::string, std::string> params = {}) {
                            65536,
                            false,
                            true,
+                           true,
                            params);
         schema->AddField(std::move(f));
     }
@@ -76,14 +77,14 @@ TEST(ParseJson, Naive) {
     }
 }
 
-TEST(ParseTokenizerParams, NoAnalyzerParams) {
+TEST(ParseTokenizerParams, NoTokenizerParams) {
     TypeParams params{{"k", "v"}};
     auto p = ParseTokenizerParams(params);
     ASSERT_EQ(0, p.size());
 }
 
 TEST(ParseTokenizerParams, Default) {
-    TypeParams params{{"analyzer_params", R"({"tokenizer": "default"})"}};
+    TypeParams params{{"tokenizer_params", R"({"tokenizer": "default"})"}};
     auto p = ParseTokenizerParams(params);
     ASSERT_EQ(1, p.size());
     auto iter = p.find("tokenizer");
@@ -251,7 +252,8 @@ TEST(TextMatch, SealedNaive) {
 TEST(TextMatch, GrowingJieBa) {
     auto schema = GenTestSchema({
         {"enable_match", "true"},
-        {"analyzer_params", R"({"tokenizer": "jieba"})"},
+        {"enable_tokenizer", "true"},
+        {"tokenizer_params", R"({"tokenizer": "jieba"})"},
     });
     auto seg = CreateGrowingSegment(schema, empty_index_meta);
     std::vector<std::string> raw_str = {"青铜时代", "黄金时代"};
@@ -327,7 +329,8 @@ TEST(TextMatch, GrowingJieBa) {
 TEST(TextMatch, SealedJieBa) {
     auto schema = GenTestSchema({
         {"enable_match", "true"},
-        {"analyzer_params", R"({"tokenizer": "jieba"})"},
+        {"enable_tokenizer", "true"},
+        {"tokenizer_params", R"({"tokenizer": "jieba"})"},
     });
     auto seg = CreateSealedSegment(schema, empty_index_meta);
     std::vector<std::string> raw_str = {"青铜时代", "黄金时代"};
diff --git a/internal/datacoord/job_manager_test.go b/internal/datacoord/job_manager_test.go
index 99a49d149e67d..7eb3ac024d9d9 100644
--- a/internal/datacoord/job_manager_test.go
+++ b/internal/datacoord/job_manager_test.go
@@ -56,6 +56,9 @@ func (s *jobManagerSuite) TestJobManager_triggerStatsTaskLoop() {
 							{
 								Key: "enable_match", Value: "true",
 							},
+							{
+								Key: "enable_tokenizer", Value: "true",
+							},
 						},
 					},
 				},
diff --git a/internal/flushcommon/pipeline/flow_graph_embedding_node_test.go b/internal/flushcommon/pipeline/flow_graph_embedding_node_test.go
index 08d8b743becbf..2bb547f118adf 100644
--- a/internal/flushcommon/pipeline/flow_graph_embedding_node_test.go
+++ b/internal/flushcommon/pipeline/flow_graph_embedding_node_test.go
@@ -21,6 +21,7 @@ import (
 	"github.com/stretchr/testify/assert"
 
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/util/flowgraph"
@@ -44,6 +45,12 @@ func TestEmbeddingNode_BM25_Operator(t *testing.T) {
 			Name:     "text",
 			FieldID:  101,
 			DataType: schemapb.DataType_VarChar,
+			TypeParams: []*commonpb.KeyValuePair{
+				{
+					Key:   "enable_tokenizer",
+					Value: "true",
+				},
+			},
 		}, {
 			Name:     "sparse",
 			FieldID:  102,
diff --git a/internal/metastore/kv/rootcoord/kv_catalog_test.go b/internal/metastore/kv/rootcoord/kv_catalog_test.go
index 44ee262970588..7dfbc27b55240 100644
--- a/internal/metastore/kv/rootcoord/kv_catalog_test.go
+++ b/internal/metastore/kv/rootcoord/kv_catalog_test.go
@@ -1243,9 +1243,31 @@ func TestCatalog_CreateCollection(t *testing.T) {
 			Partitions: []*model.Partition{
 				{PartitionName: "test"},
 			},
-			Fields:     []*model.Field{{Name: "text", DataType: schemapb.DataType_VarChar}, {Name: "sparse", DataType: schemapb.DataType_SparseFloatVector}},
-			Functions:  []*model.Function{{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"text"}, OutputFieldNames: []string{"sparse"}}},
-			State:      pb.CollectionState_CollectionCreating,
+			Fields: []*model.Field{
+				{
+					Name:     "text",
+					DataType: schemapb.DataType_VarChar,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   "enable_tokenizer",
+							Value: "true",
+						},
+					},
+				},
+				{
+					Name:     "sparse",
+					DataType: schemapb.DataType_SparseFloatVector,
+				},
+			},
+			Functions: []*model.Function{
+				{
+					Name:             "test",
+					Type:             schemapb.FunctionType_BM25,
+					InputFieldNames:  []string{"text"},
+					OutputFieldNames: []string{"sparse"},
+				},
+			},
+			State: pb.CollectionState_CollectionCreating,
 		}
 		err := kc.CreateCollection(ctx, coll, 100)
 		assert.NoError(t, err)
@@ -1325,9 +1347,31 @@ func TestCatalog_DropCollection(t *testing.T) {
 			Partitions: []*model.Partition{
 				{PartitionName: "test"},
 			},
-			Fields:     []*model.Field{{Name: "text", DataType: schemapb.DataType_VarChar}, {Name: "sparse", DataType: schemapb.DataType_SparseFloatVector}},
-			Functions:  []*model.Function{{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"text"}, OutputFieldNames: []string{"sparse"}}},
-			State:      pb.CollectionState_CollectionDropping,
+			Fields: []*model.Field{
+				{
+					Name:     "text",
+					DataType: schemapb.DataType_VarChar,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   "enable_tokenizer",
+							Value: "true",
+						},
+					},
+				},
+				{
+					Name:     "sparse",
+					DataType: schemapb.DataType_SparseFloatVector,
+				},
+			},
+			Functions: []*model.Function{
+				{
+					Name:             "test",
+					Type:             schemapb.FunctionType_BM25,
+					InputFieldNames:  []string{"text"},
+					OutputFieldNames: []string{"sparse"},
+				},
+			},
+			State: pb.CollectionState_CollectionDropping,
 		}
 		err := kc.DropCollection(ctx, coll, 100)
 		assert.NoError(t, err)
diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go
index c9dacdef07896..e60d046f2d50c 100644
--- a/internal/proxy/task_test.go
+++ b/internal/proxy/task_test.go
@@ -3116,6 +3116,10 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
 				Key:   "max_length",
 				Value: strconv.Itoa(testMaxVarCharLength),
 			},
+			{
+				Key:   "enable_tokenizer",
+				Value: "true",
+			},
 		},
 	}
 	floatVecField := &schemapb.FieldSchema{
diff --git a/internal/proxy/util.go b/internal/proxy/util.go
index 498370c0e7666..da0ea00cb89db 100644
--- a/internal/proxy/util.go
+++ b/internal/proxy/util.go
@@ -697,6 +697,10 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema
 			return fmt.Errorf("only one VARCHAR input field is allowed for a BM25 Function, got %d field with type %s",
 				len(fields), fields[0].DataType.String())
 		}
+		h := typeutil.CreateFieldSchemaHelper(fields[0])
+		if !h.EnableTokenizer() {
+			return fmt.Errorf("BM25 input field must set enable_tokenizer to true")
+		}
 
 	default:
 		return fmt.Errorf("check input field with unknown function type")
@@ -739,7 +743,7 @@ func checkFunctionParams(function *schemapb.FunctionSchema) error {
 				return fmt.Errorf("bm25_avgdl must large than zero but now %f", avgdl)
 			}
 
-		case "analyzer_params":
+		case "tokenizer_params":
 			// TODO ADD tokenizer check
 		default:
 			return fmt.Errorf("invalid function params, key: %s, value:%s", kv.GetKey(), kv.GetValue())
diff --git a/internal/util/ctokenizer/text_schema_validator.go b/internal/util/ctokenizer/text_schema_validator.go
index fa6085a345b51..14d27b1bb9893 100644
--- a/internal/util/ctokenizer/text_schema_validator.go
+++ b/internal/util/ctokenizer/text_schema_validator.go
@@ -23,6 +23,10 @@ func ValidateTextSchema(fieldSchema *schemapb.FieldSchema) error {
 		return nil
 	}
 
+	if !h.EnableTokenizer() {
+		return fmt.Errorf("field %s is set to enable match but not enable tokenizer", fieldSchema.Name)
+	}
+
 	bs, err := proto.Marshal(fieldSchema)
 	if err != nil {
 		return fmt.Errorf("failed to marshal field schema: %w", err)
diff --git a/internal/util/ctokenizer/text_schema_validator_test.go b/internal/util/ctokenizer/text_schema_validator_test.go
index 104e0a9934924..56e3ba668c5cb 100644
--- a/internal/util/ctokenizer/text_schema_validator_test.go
+++ b/internal/util/ctokenizer/text_schema_validator_test.go
@@ -1,6 +1,7 @@
 package ctokenizer
 
 import (
+	"fmt"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -9,71 +10,64 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 )
 
-func TestValidateTextSchema(t *testing.T) {
-	type args struct {
-		fieldSchema *schemapb.FieldSchema
+func TestValidateEmptyTextSchema(t *testing.T) {
+	fs := &schemapb.FieldSchema{
+		FieldID:    101,
+		DataType:   schemapb.DataType_VarChar,
+		TypeParams: []*commonpb.KeyValuePair{},
 	}
-	tests := []struct {
-		name     string
-		args     args
-		errIsNil bool
-	}{
-		{
-			args: args{
-				fieldSchema: &schemapb.FieldSchema{
-					FieldID:    101,
-					TypeParams: []*commonpb.KeyValuePair{},
-				},
-			},
-			errIsNil: true,
-		},
+	assert.Nil(t, ValidateTextSchema(fs))
+}
+
+func TestValidateTextSchema(t *testing.T) {
+	tests := []*schemapb.FieldSchema{
 		{
-			// default
-			args: args{
-				fieldSchema: &schemapb.FieldSchema{
-					FieldID: 101,
-					TypeParams: []*commonpb.KeyValuePair{
-						{Key: "enable_match", Value: "true"},
-					},
-				},
+			FieldID:  101,
+			DataType: schemapb.DataType_VarChar,
+			TypeParams: []*commonpb.KeyValuePair{
+				{Key: "enable_match", Value: "true"},
 			},
-			errIsNil: true,
 		},
 		{
-			// default
-			args: args{
-				fieldSchema: &schemapb.FieldSchema{
-					FieldID: 101,
-					TypeParams: []*commonpb.KeyValuePair{
-						{Key: "enable_match", Value: "true"},
-						{Key: "analyzer_params", Value: `{"tokenizer": "default"}`},
-					},
-				},
+			FieldID:  101,
+			DataType: schemapb.DataType_VarChar,
+			TypeParams: []*commonpb.KeyValuePair{
+				{Key: "enable_match", Value: "true"},
+				{Key: "tokenizer_params", Value: `{"tokenizer": "default"}`},
 			},
-			errIsNil: true,
 		},
 		{
-			// jieba
-			args: args{
-				fieldSchema: &schemapb.FieldSchema{
-					FieldID: 101,
-					TypeParams: []*commonpb.KeyValuePair{
-						{Key: "enable_match", Value: "true"},
-						{Key: "analyzer_params", Value: `{"tokenizer": "jieba"}`},
-					},
-				},
+			FieldID:  101,
+			DataType: schemapb.DataType_VarChar,
+			TypeParams: []*commonpb.KeyValuePair{
+				{Key: "enable_match", Value: "true"},
+				{Key: "tokenizer_params", Value: `{"tokenizer": "jieba"}`},
 			},
-			errIsNil: true,
 		},
 	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			err := ValidateTextSchema(tt.args.fieldSchema)
-			if tt.errIsNil {
-				assert.Nil(t, err)
-			} else {
-				assert.NotNil(t, err)
-			}
+
+	for idx, tt := range tests {
+		t.Run(fmt.Sprintf("enable_tokenizer not set %d", idx), func(t *testing.T) {
+			err := ValidateTextSchema(tt)
+			assert.NotNil(t, err)
+		})
+	}
+
+	for idx, tt := range tests {
+		t.Run(fmt.Sprintf("enable_tokenizer set to false %d", idx), func(t *testing.T) {
+			tt.TypeParams = append(tt.TypeParams, &commonpb.KeyValuePair{
+				Key:   "enable_tokenizer",
+				Value: "false",
+			})
+			err := ValidateTextSchema(tt)
+			assert.NotNil(t, err)
+		})
+	}
+	for idx, tt := range tests {
+		t.Run(fmt.Sprintf("enable_tokenizer set to true %d", idx), func(t *testing.T) {
+			tt.TypeParams[len(tt.TypeParams)-1].Value = "true"
+			err := ValidateTextSchema(tt)
+			assert.Nil(t, err)
 		})
 	}
 }
diff --git a/pkg/util/typeutil/field_schema.go b/pkg/util/typeutil/field_schema.go
index b5cb4b240ceb6..bbf3ab446719a 100644
--- a/pkg/util/typeutil/field_schema.go
+++ b/pkg/util/typeutil/field_schema.go
@@ -53,6 +53,18 @@ func (h *FieldSchemaHelper) EnableMatch() bool {
 	return err == nil && enable
 }
 
+func (h *FieldSchemaHelper) EnableTokenizer() bool {
+	if !IsStringType(h.schema.GetDataType()) {
+		return false
+	}
+	s, err := h.typeParams.Get("enable_tokenizer")
+	if err != nil {
+		return false
+	}
+	enable, err := strconv.ParseBool(s)
+	return err == nil && enable
+}
+
 func CreateFieldSchemaHelper(schema *schemapb.FieldSchema) *FieldSchemaHelper {
 	return &FieldSchemaHelper{
 		schema: schema,
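
A minimal usage sketch, assuming the packages touched by this patch (github.com/milvus-io/milvus/pkg/util/typeutil and the milvus-proto go-api packages) are importable from a test program: after this change a VarChar field that sets enable_match is also expected to set enable_tokenizer, and FieldSchemaHelper.EnableTokenizer reports true only when the field is a string type whose enable_tokenizer type param parses as a boolean true.

	package main

	import (
		"fmt"

		"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
		"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
		"github.com/milvus-io/milvus/pkg/util/typeutil"
	)

	func main() {
		// VarChar field mirroring the test fixtures in this patch:
		// enable_match now requires enable_tokenizer; tokenizer_params stays optional.
		field := &schemapb.FieldSchema{
			FieldID:  101,
			Name:     "text",
			DataType: schemapb.DataType_VarChar,
			TypeParams: []*commonpb.KeyValuePair{
				{Key: "max_length", Value: "65536"},
				{Key: "enable_match", Value: "true"},
				{Key: "enable_tokenizer", Value: "true"},
				{Key: "tokenizer_params", Value: `{"tokenizer": "default"}`},
			},
		}

		h := typeutil.CreateFieldSchemaHelper(field)
		// EnableTokenizer is false for non-string fields or when the
		// enable_tokenizer type param is missing or not a boolean true.
		fmt.Println(h.EnableMatch(), h.EnableTokenizer()) // true true
	}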