diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index df952c41..45bc66d7 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -63,7 +63,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ // "searchModelVersions", // "getModelVersionDownloadUri", "setRegisteredModelTag", - // "setModelVersionTag", + "setModelVersionTag", "deleteRegisteredModelTag", // "deleteModelVersionTag", "setRegisteredModelAlias", diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index 9a5ac7bf..c55a9f1e 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -64,4 +64,8 @@ var validations = map[string]string{ "DeleteRegisteredModelAlias_Alias": "required,max=255,validMetricParamOrTagName,pathIsUnique", "GetModelVersionByAlias_Name": "required", "GetModelVersionByAlias_Alias": "required,max=255,validMetricParamOrTagName,pathIsUnique", + "SetModelVersionTag_Name": "required", + "SetModelVersionTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "SetModelVersionTag_Value": "omitempty,max=5000,truncate=5000", + "SetModelVersionTag_Version": "stringAsInteger", } diff --git a/mlflow_go_backend/store/model_registry.py b/mlflow_go_backend/store/model_registry.py index 767fbf10..7c7f3c9f 100644 --- a/mlflow_go_backend/store/model_registry.py +++ b/mlflow_go_backend/store/model_registry.py @@ -13,6 +13,7 @@ GetModelVersionByAlias, GetRegisteredModel, RenameRegisteredModel, + SetModelVersionTag, SetRegisteredModelAlias, SetRegisteredModelTag, TransitionModelVersionStage, @@ -158,6 +159,10 @@ def get_model_version_by_alias(self, name, alias): ) return ModelVersion.from_proto(response.model_version) + def set_model_version_tag(self, name, version, tag): + request = SetModelVersionTag(name=name, version=str(version), key=tag.key, value=tag.value) + self.service.call_endpoint(get_lib().ModelRegistryServiceSetModelVersionTag, request) + def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index 5fadc1e3..d4e8f6f3 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -21,6 +21,7 @@ type ModelRegistryService interface { DeleteModelVersion(ctx context.Context, input *protos.DeleteModelVersion) (*protos.DeleteModelVersion_Response, *contract.Error) GetModelVersion(ctx context.Context, input *protos.GetModelVersion) (*protos.GetModelVersion_Response, *contract.Error) SetRegisteredModelTag(ctx context.Context, input *protos.SetRegisteredModelTag) (*protos.SetRegisteredModelTag_Response, *contract.Error) + SetModelVersionTag(ctx context.Context, input *protos.SetModelVersionTag) (*protos.SetModelVersionTag_Response, *contract.Error) DeleteRegisteredModelTag(ctx context.Context, input *protos.DeleteRegisteredModelTag) (*protos.DeleteRegisteredModelTag_Response, *contract.Error) SetRegisteredModelAlias(ctx context.Context, input *protos.SetRegisteredModelAlias) (*protos.SetRegisteredModelAlias_Response, *contract.Error) DeleteRegisteredModelAlias(ctx context.Context, input *protos.DeleteRegisteredModelAlias) (*protos.DeleteRegisteredModelAlias_Response, *contract.Error) diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index 187d85e6..11dcac5c 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -95,6 +95,14 @@ func ModelRegistryServiceSetRegisteredModelTag(serviceID int64, requestData unsa } return invokeServiceMethod(service.SetRegisteredModelTag, new(protos.SetRegisteredModelTag), requestData, requestSize, responseSize) } +//export ModelRegistryServiceSetModelVersionTag +func ModelRegistryServiceSetModelVersionTag(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.SetModelVersionTag, new(protos.SetModelVersionTag), requestData, requestSize, responseSize) +} //export ModelRegistryServiceDeleteRegisteredModelTag func ModelRegistryServiceDeleteRegisteredModelTag(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index 4e1fedcd..5fbb3cc2 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -114,3 +114,19 @@ func (m *ModelRegistryService) GetModelVersionByAlias( ModelVersion: modelVersion.ToProto(), }, nil } + +func (m *ModelRegistryService) SetModelVersionTag( + ctx context.Context, input *protos.SetModelVersionTag, +) (*protos.SetModelVersionTag_Response, *contract.Error) { + if err := m.store.SetModelVersionTag( + ctx, + input.GetName(), + input.GetVersion(), + input.GetKey(), + input.GetValue(), + ); err != nil { + return nil, err + } + + return &protos.SetModelVersionTag_Response{}, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 86568beb..0d63f96f 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -10,6 +10,7 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" "github.com/mlflow/mlflow-go-backend/pkg/contract" "github.com/mlflow/mlflow-go-backend/pkg/entities" @@ -333,3 +334,29 @@ func (m *ModelRegistrySQLStore) GetModelVersionByAlias( return modelVersion, nil } + +func (m *ModelRegistrySQLStore) SetModelVersionTag( + ctx context.Context, name, version, key, value string, +) *contract.Error { + modelVersion, err := m.GetModelVersion(ctx, name, version, false) + if err != nil { + return err + } + + if err := m.db.WithContext( + ctx, + ).Clauses(clause.OnConflict{ + UpdateAll: true, + }).Create(&models.ModelVersionTag{ + Key: key, + Value: value, + Name: modelVersion.Name, + Version: modelVersion.Version, + }).Error; err != nil { + return contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, "error setting model version tag", err, + ) + } + + return nil +} diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index 9776003c..a9e0f0cf 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -24,6 +24,7 @@ type ModelVersionStore interface { ctx context.Context, name, version string, stage models.ModelVersionStage, archiveExistingVersions bool, ) (*entities.ModelVersion, *contract.Error) GetModelVersionByAlias(ctx context.Context, name, alias string) (*entities.ModelVersion, *contract.Error) + SetModelVersionTag(ctx context.Context, name, version, key, value string) *contract.Error } type RegisteredModelStore interface { diff --git a/pkg/protos/model_registry.pb.go b/pkg/protos/model_registry.pb.go index bebc6072..43e70b70 100644 --- a/pkg/protos/model_registry.pb.go +++ b/pkg/protos/model_registry.pb.go @@ -1427,15 +1427,15 @@ type SetModelVersionTag struct { unknownFields protoimpl.UnknownFields // Unique name of the model. - Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name"` + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name" validate:"required"` // Model version number. - Version *string `protobuf:"bytes,2,opt,name=version" json:"version,omitempty" query:"version" params:"version"` + Version *string `protobuf:"bytes,2,opt,name=version" json:"version,omitempty" query:"version" params:"version" validate:"stringAsInteger"` // Name of the tag. Maximum size depends on storage backend. // If a tag with this name already exists, its preexisting value will be replaced by the specified `value`. // All storage backends are guaranteed to support key values up to 250 bytes in size. - Key *string `protobuf:"bytes,3,opt,name=key" json:"key,omitempty" query:"key" params:"key"` + Key *string `protobuf:"bytes,3,opt,name=key" json:"key,omitempty" query:"key" params:"key" validate:"required,max=250,validMetricParamOrTagName,pathIsUnique"` // String value of the tag being logged. Maximum size depends on storage backend. - Value *string `protobuf:"bytes,4,opt,name=value" json:"value,omitempty" query:"value" params:"value"` + Value *string `protobuf:"bytes,4,opt,name=value" json:"value,omitempty" query:"value" params:"value" validate:"omitempty,max=5000,truncate=5000"` } func (x *SetModelVersionTag) Reset() { diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index 4e9ca38b..a5d26532 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -143,6 +143,17 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Post("/mlflow/model-versions/set-tag", func(ctx *fiber.Ctx) error { + input := &protos.SetModelVersionTag{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.SetModelVersionTag(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Delete("/mlflow/registered-models/delete-tag", func(ctx *fiber.Ctx) error { input := &protos.DeleteRegisteredModelTag{} if err := parser.ParseBody(ctx, input); err != nil { diff --git a/pkg/validation/validation.go b/pkg/validation/validation.go index dd08038c..c7a3f3aa 100644 --- a/pkg/validation/validation.go +++ b/pkg/validation/validation.go @@ -91,6 +91,14 @@ func notEmptyValidation(fl validator.FieldLevel) bool { return fl.Field().String() != "" } +func stringAsInteger(fl validator.FieldLevel) bool { + if _, err := strconv.Atoi(fl.Field().String()); err != nil { + return false + } + + return true +} + func regexValidation(regex *regexp.Regexp) validator.Func { return func(fl validator.FieldLevel) bool { valueStr := fl.Field().String() @@ -150,7 +158,7 @@ func truncateFn(fieldLevel validator.FieldLevel) bool { return true } -//nolint:cyclop +//nolint:cyclop,funlen func NewValidator() (*validator.Validate, error) { validate := validator.New() @@ -208,6 +216,10 @@ func NewValidator() (*validator.Validate, error) { return nil, fmt.Errorf("validation registration for 'notEmpty' failed: %w", err) } + if err := validate.RegisterValidation("stringAsInteger", stringAsInteger); err != nil { + return nil, fmt.Errorf("validation registration for 'notEmpty' failed: %w", err) + } + validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{}) validate.RegisterStructValidation(validateSetTagRunIDExists, &protos.SetTag{}) @@ -324,6 +336,11 @@ func NewErrorFromValidationError(err error) *contract.Error { validationErrors = append(validationErrors, mkMaxValidationError(field, value, err)) case "positiveNonZeroInteger": validationErrors = append(validationErrors, mkPositiveNonZeroIntegerError(field, value)) + case "stringAsInteger": + validationErrors = append( + validationErrors, + fmt.Sprintf("Parameter '%s' must be an integer, got '%s'", field, value), + ) default: validationErrors = append( validationErrors,