Skip to content

Commit

Permalink
Move POST /mlflow/model-versions/set-tag endpoint.
Browse files Browse the repository at this point in the history
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
  • Loading branch information
dsuhinin committed Jan 28, 2025
1 parent 1fea027 commit 103dba9
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 6 deletions.
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
// "searchModelVersions",
// "getModelVersionDownloadUri",
"setRegisteredModelTag",
// "setModelVersionTag",
"setModelVersionTag",
"deleteRegisteredModelTag",
// "deleteModelVersionTag",
"setRegisteredModelAlias",
Expand Down
4 changes: 4 additions & 0 deletions magefiles/generate/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,8 @@ var validations = map[string]string{
"DeleteRegisteredModelAlias_Alias": "required,max=255,validMetricParamOrTagName,pathIsUnique",
"GetModelVersionByAlias_Name": "required",
"GetModelVersionByAlias_Alias": "required,max=255,validMetricParamOrTagName,pathIsUnique",
"SetModelVersionTag_Name": "required",
"SetModelVersionTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique",
"SetModelVersionTag_Value": "omitempty,max=5000,truncate=5000",
"SetModelVersionTag_Version": "stringAsInteger",
}
5 changes: 5 additions & 0 deletions mlflow_go_backend/store/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
GetModelVersionByAlias,
GetRegisteredModel,
RenameRegisteredModel,
SetModelVersionTag,
SetRegisteredModelAlias,
SetRegisteredModelTag,
TransitionModelVersionStage,
Expand Down Expand Up @@ -158,6 +159,10 @@ def get_model_version_by_alias(self, name, alias):
)
return ModelVersion.from_proto(response.model_version)

def set_model_version_tag(self, name, version, tag):
request = SetModelVersionTag(name=name, version=str(version), key=tag.key, value=tag.value)
self.service.call_endpoint(get_lib().ModelRegistryServiceSetModelVersionTag, request)


def ModelRegistryStore(cls):
return type(cls.__name__, (_ModelRegistryStore, cls), {})
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pkg/lib/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions pkg/model_registry/service/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,19 @@ func (m *ModelRegistryService) GetModelVersionByAlias(
ModelVersion: modelVersion.ToProto(),
}, nil
}

func (m *ModelRegistryService) SetModelVersionTag(
ctx context.Context, input *protos.SetModelVersionTag,
) (*protos.SetModelVersionTag_Response, *contract.Error) {
if err := m.store.SetModelVersionTag(
ctx,
input.GetName(),
input.GetVersion(),
input.GetKey(),
input.GetValue(),
); err != nil {
return nil, err
}

return &protos.SetModelVersionTag_Response{}, nil
}
27 changes: 27 additions & 0 deletions pkg/model_registry/store/sql/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"gorm.io/gorm"
"gorm.io/gorm/clause"

"github.com/mlflow/mlflow-go-backend/pkg/contract"
"github.com/mlflow/mlflow-go-backend/pkg/entities"
Expand Down Expand Up @@ -333,3 +334,29 @@ func (m *ModelRegistrySQLStore) GetModelVersionByAlias(

return modelVersion, nil
}

func (m *ModelRegistrySQLStore) SetModelVersionTag(
ctx context.Context, name, version, key, value string,
) *contract.Error {
modelVersion, err := m.GetModelVersion(ctx, name, version, false)
if err != nil {
return err
}

if err := m.db.WithContext(
ctx,
).Clauses(clause.OnConflict{
UpdateAll: true,
}).Create(&models.ModelVersionTag{
Key: key,
Value: value,
Name: modelVersion.Name,
Version: modelVersion.Version,
}).Error; err != nil {
return contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR, "error setting model version tag", err,
)
}

return nil
}
1 change: 1 addition & 0 deletions pkg/model_registry/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type ModelVersionStore interface {
ctx context.Context, name, version string, stage models.ModelVersionStage, archiveExistingVersions bool,
) (*entities.ModelVersion, *contract.Error)
GetModelVersionByAlias(ctx context.Context, name, alias string) (*entities.ModelVersion, *contract.Error)
SetModelVersionTag(ctx context.Context, name, version, key, value string) *contract.Error
}

type RegisteredModelStore interface {
Expand Down
8 changes: 4 additions & 4 deletions pkg/protos/model_registry.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions pkg/server/routes/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 18 additions & 1 deletion pkg/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ func notEmptyValidation(fl validator.FieldLevel) bool {
return fl.Field().String() != ""
}

func stringAsInteger(fl validator.FieldLevel) bool {
if _, err := strconv.Atoi(fl.Field().String()); err != nil {
return false
}

return true
}

func regexValidation(regex *regexp.Regexp) validator.Func {
return func(fl validator.FieldLevel) bool {
valueStr := fl.Field().String()
Expand Down Expand Up @@ -150,7 +158,7 @@ func truncateFn(fieldLevel validator.FieldLevel) bool {
return true
}

//nolint:cyclop
//nolint:cyclop,funlen
func NewValidator() (*validator.Validate, error) {
validate := validator.New()

Expand Down Expand Up @@ -208,6 +216,10 @@ func NewValidator() (*validator.Validate, error) {
return nil, fmt.Errorf("validation registration for 'notEmpty' failed: %w", err)
}

if err := validate.RegisterValidation("stringAsInteger", stringAsInteger); err != nil {
return nil, fmt.Errorf("validation registration for 'notEmpty' failed: %w", err)
}

validate.RegisterStructValidation(validateLogBatchLimits, &protos.LogBatch{})
validate.RegisterStructValidation(validateSetTagRunIDExists, &protos.SetTag{})

Expand Down Expand Up @@ -324,6 +336,11 @@ func NewErrorFromValidationError(err error) *contract.Error {
validationErrors = append(validationErrors, mkMaxValidationError(field, value, err))
case "positiveNonZeroInteger":
validationErrors = append(validationErrors, mkPositiveNonZeroIntegerError(field, value))
case "stringAsInteger":
validationErrors = append(
validationErrors,
fmt.Sprintf("Parameter '%s' must be an integer, got '%s'", field, value),
)
default:
validationErrors = append(
validationErrors,
Expand Down

0 comments on commit 103dba9

Please # to comment.