Skip to content

Commit

Permalink
Trying to fix issue with NaN, Inf, -Inf
Browse files Browse the repository at this point in the history
jgiannuzzi/mlflow-go#45

Signed-off-by: Software Developer <7852635+dsuhinin@users.noreply.github.com>
  • Loading branch information
dsuhinin authored Sep 24, 2024
1 parent 258a61e commit e0dbbb2
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 149 deletions.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,12 @@ Sometimes `golangci-lint` can complain about unrelated files, run `golangci-lint
The following Python tests are currently failing:

```
========================================================================================================================= short test summary info ==========================================================================================================================
FAILED .mlflow.repo/tests/tracking/test_rest_tracking.py::test_log_metrics_params_tags[sqlalchemy] - mlflow.exceptions.RestException: INVALID_PARAMETER_VALUE: Invalid value "NaN" for parameter 'value' supplied
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_metric_concurrent_logging_succeeds - mlflow.exceptions.MlflowException: error creating metrics in batch for run_uuid "23ad9801cd064fa1bc14b458ed5f609b"
============================================================================= short test summary info ==============================================================================
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_metric_concurrent_logging_succeeds - mlflow.exceptions.MlflowException: error creating metrics in batch for run_uuid "de412cd72bcb4268a4d90d6c52ce4d3c"
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_batch_null_metrics - TypeError: must be real number, not NoneType
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_log_inputs_with_large_inputs_limit_check - AssertionError: assert {'digest': 'd...ema': '', ...} == {'digest': 'd...a': None, ...}
FAILED .mlflow.repo/tests/store/tracking/test_sqlalchemy_store.py::test_sqlalchemy_store_behaves_as_expected_with_inmemory_sqlite_db - mlflow.exceptions.MlflowException: failed to create experiment
============================================================================================ 5 failed, 354 passed, 9 skipped, 128 deselected, 10 warnings in 231.73s (0:03:51) =============================================================================================
================================================ 4 failed, 355 passed, 9 skipped, 128 deselected, 10 warnings in 557.37s (0:09:17) =================================================
```

## Debug Failing Tests
Expand Down
15 changes: 14 additions & 1 deletion pkg/entities/metric.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
package entities

import (
"math"

"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
)

type Metric struct {
Key string
Value float64
Timestamp int64
Step int64
IsNaN bool
}

func (m Metric) ToProto() *protos.Metric {
return &protos.Metric{
metric := protos.Metric{
Key: &m.Key,
Value: &m.Value,
Timestamp: &m.Timestamp,
Step: &m.Step,
}

switch {
case m.IsNaN:
metric.Value = utils.PtrTo(math.NaN())
default:
metric.Value = &m.Value
}

return &metric
}

func MetricFromProto(proto *protos.Metric) *Metric {
Expand Down
28 changes: 18 additions & 10 deletions pkg/server/parser/http_request_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"github.com/tidwall/gjson"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"

"github.com/mlflow/mlflow-go/pkg/contract"
"github.com/mlflow/mlflow-go/pkg/protos"
Expand All @@ -29,19 +31,25 @@ func NewHTTPRequestParser() (*HTTPRequestParser, error) {
}, nil
}

func (p *HTTPRequestParser) ParseBody(ctx *fiber.Ctx, input interface{}) *contract.Error {
if err := ctx.BodyParser(input); err != nil {
var unmarshalTypeError *json.UnmarshalTypeError
if errors.As(err, &unmarshalTypeError) {
result := gjson.GetBytes(ctx.Body(), unmarshalTypeError.Field)
func (p *HTTPRequestParser) ParseBody(ctx *fiber.Ctx, input proto.Message) *contract.Error {
if protojsonErr := protojson.Unmarshal(ctx.Body(), input); protojsonErr != nil {
// falling back to JSON, because `protojson` doesn't provide any information
// about `field` name for which ut fails. MLFlow tests expect to know the exact
// `field` name where validation failed. This approach has no effect on MLFlow
// tests, so let's keep it for now.
if jsonErr := json.Unmarshal(ctx.Body(), input); jsonErr != nil {
var unmarshalTypeError *json.UnmarshalTypeError
if errors.As(jsonErr, &unmarshalTypeError) {
result := gjson.GetBytes(ctx.Body(), unmarshalTypeError.Field)

return contract.NewError(
protos.ErrorCode_INVALID_PARAMETER_VALUE,
fmt.Sprintf("Invalid value %s for parameter '%s' supplied", result.Raw, unmarshalTypeError.Field),
)
return contract.NewError(
protos.ErrorCode_INVALID_PARAMETER_VALUE,
fmt.Sprintf("Invalid value %s for parameter '%s' supplied", result.Raw, unmarshalTypeError.Field),
)
}
}

return contract.NewError(protos.ErrorCode_BAD_REQUEST, err.Error())
return contract.NewError(protos.ErrorCode_BAD_REQUEST, protojsonErr.Error())
}

if err := p.validator.Struct(input); err != nil {
Expand Down
30 changes: 24 additions & 6 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
Expand All @@ -13,6 +14,8 @@ import (
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/proxy"
"github.com/gofiber/fiber/v2/middleware/recover"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"

as "github.com/mlflow/mlflow-go/pkg/artifacts/service"
mr "github.com/mlflow/mlflow-go/pkg/model_registry/service"
Expand All @@ -26,15 +29,30 @@ import (
"github.com/mlflow/mlflow-go/pkg/utils"
)

//nolint:funlen
func configureApp(ctx context.Context, cfg *config.Config) (*fiber.App, error) {
//nolint:mnd
app := fiber.New(fiber.Config{
BodyLimit: 16 * 1024 * 1024,
ReadBufferSize: 16384,
ReadTimeout: 5 * time.Second,
WriteTimeout: 600 * time.Second,
IdleTimeout: 120 * time.Second,
ServerHeader: "mlflow/" + cfg.Version,
BodyLimit: 16 * 1024 * 1024,
ReadBufferSize: 16384,
ReadTimeout: 5 * time.Second,
WriteTimeout: 600 * time.Second,
IdleTimeout: 120 * time.Second,
ServerHeader: "mlflow/" + cfg.Version,
JSONEncoder: func(v interface{}) ([]byte, error) {
if protoMessage, ok := v.(proto.Message); ok {
return protojson.Marshal(protoMessage)
}

return json.Marshal(v)
},
JSONDecoder: func(data []byte, v interface{}) error {
if protoMessage, ok := v.(proto.Message); ok {
return protojson.Unmarshal(data, protoMessage)
}

return json.Unmarshal(data, v)
},
DisableStartupMessage: true,
})

Expand Down
Loading

0 comments on commit e0dbbb2

Please # to comment.