Skip to content

Commit

Permalink
Add TypeMerger (#445)
Browse files Browse the repository at this point in the history
This is called `TypeMerger` but it is really multiple dispatch for functions with type `([some Type], [some Type]) → Type`. Ideally this would be a bit more generalized to _n_ arguments but Go doesn’t support that kind of thing.
  • Loading branch information
Porges authored Apr 30, 2021
1 parent 351c853 commit 8f8b86c
Show file tree
Hide file tree
Showing 4 changed files with 397 additions and 171 deletions.
174 changes: 174 additions & 0 deletions hack/generator/pkg/astmodel/type_merger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

package astmodel

import (
"reflect"

"github.com/pkg/errors"
)

// TypeMerger is like a visitor for 2 types.
//
// Conceptually it takes (via Add) a list of functions of the form:
//
// func ([ctx interface{},] left {some Type}, right {some Type}) (Type, error)
//
// where `left` and `right` can be concrete types that implement the `Type` interface.
//
// When `TypeMerger.Merge(Type, Type)` is invoked, it will iterate through the
// provided functions and invoke the first one that matches the concrete types
// passed. If none match then the fallback provided to `NewTypeMerger` will be invoked.
//
// The `ctx` argument can optionally be used to “smuggle” additional data down the call-chain.
type TypeMerger struct {
mergers []mergerRegistration
fallback MergerFunc
}

type mergerRegistration struct {
left reflect.Type
right reflect.Type
merge MergerFunc
}

type MergerFunc func(ctx interface{}, left, right Type) (Type, error)

func NewTypeMerger(fallback MergerFunc) TypeMerger {
return TypeMerger{fallback: fallback}
}

var typeInterface reflect.Type = reflect.TypeOf((*Type)(nil)).Elem() // yuck
var errorInterface reflect.Type = reflect.TypeOf((*error)(nil)).Elem()
var mergerFuncType reflect.Type = reflect.TypeOf((*MergerFunc)(nil)).Elem()

type validatedMerger struct {
merger reflect.Value
leftArgType, rightArgType reflect.Type
needsCtx bool
}

func validateMerger(merger interface{}) validatedMerger {
it := reflect.ValueOf(merger)
if it.Kind() != reflect.Func {
panic("merger must be a function")
}

mergerType := it.Type()

badArgumentsMsg := "merger must take take arguments of type (left [some Type], right [some Type]) or (ctx X, left [some Type], right [some Type])"
if mergerType.NumIn() < 2 || mergerType.NumIn() > 3 {
panic(badArgumentsMsg)
}

argOffset := mergerType.NumIn() - 2

needsCtx := argOffset != 0
leftArg := mergerType.In(argOffset + 0)
rightArg := mergerType.In(argOffset + 1)

if !leftArg.AssignableTo(typeInterface) || !rightArg.AssignableTo(typeInterface) {
panic(badArgumentsMsg)
}

if mergerType.NumOut() != 2 ||
mergerType.Out(0) != typeInterface ||
mergerType.Out(1) != errorInterface {
panic("merger must return (Type, error)")
}

return validatedMerger{
merger: it,
leftArgType: leftArg,
rightArgType: rightArg,
needsCtx: needsCtx,
}
}

func buildMergerRegistration(v validatedMerger, flip bool) mergerRegistration {
leftArgType := v.leftArgType
rightArgType := v.rightArgType
if flip {
leftArgType, rightArgType = rightArgType, leftArgType
}

return mergerRegistration{
left: leftArgType,
right: rightArgType,
merge: reflect.MakeFunc(mergerFuncType, func(args []reflect.Value) []reflect.Value {
// we dereference the Type here to the underlying value so that
// the merger can take either Type or a specific type.
// if it takes Type then the compiler/runtime will convert the value back to a Type
ctxValue := args[0].Elem()
leftValue := args[1].Elem()
rightValue := args[2].Elem()
if flip {
leftValue, rightValue = rightValue, leftValue
}

if v.needsCtx {
return v.merger.Call([]reflect.Value{ctxValue, leftValue, rightValue})
} else {
return v.merger.Call([]reflect.Value{leftValue, rightValue})
}
}).Interface().(MergerFunc),
}
}

// Add adds a handler function to be invoked if applicable. See the docs on
// TypeMerger above for a full explanation.
func (m *TypeMerger) Add(mergeFunc interface{}) {
v := validateMerger(mergeFunc)
m.mergers = append(m.mergers, buildMergerRegistration(v, false))
}

// AddUnordered adds a handler function that doesn’t care what order
// the two type parameters are in. e.g. if it has type `(A, B) -> (Type, error)`,
// it can match either `(A, B)` or `(B, A)`. This is useful when the merger
// is symmetric and handles two different types.
func (m *TypeMerger) AddUnordered(mergeFunc interface{}) {
v := validateMerger(mergeFunc)
m.mergers = append(m.mergers, buildMergerRegistration(v, false), buildMergerRegistration(v, true))
}

// Merge merges the two types according to the provided mergers and fallback, with nil context
func (m *TypeMerger) Merge(left, right Type) (Type, error) {
return m.MergeWithContext(nil, left, right)
}

// MergeWithContext merges the two types according to the provided mergers and fallback, with the provided context
func (m *TypeMerger) MergeWithContext(ctx interface{}, left, right Type) (Type, error) {
if left == nil {
return right, nil
}

if right == nil {
return left, nil
}

leftType := reflect.ValueOf(left).Type()
rightType := reflect.ValueOf(right).Type()

for _, merger := range m.mergers {
leftTypeMatches := merger.left == leftType || merger.left == typeInterface
rightTypeMatches := merger.right == rightType || merger.right == typeInterface

if leftTypeMatches && rightTypeMatches {
result, err := merger.merge(ctx, left, right)
if (result == nil && err == nil) || errors.Is(err, ContinueMerge) {
// these conditions indicate that the merger was not actually applicable,
// despite having a type that matched
continue
}

return result, err
}
}

return m.fallback(ctx, left, right)
}

var ContinueMerge error = errors.New("special error that indicates that the merger was not applicable")
168 changes: 168 additions & 0 deletions hack/generator/pkg/astmodel/type_merger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

package astmodel

import (
"testing"

. "github.com/onsi/gomega"
"github.com/pkg/errors"
)

func TestCanMergeSameTypes(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(func(_ interface{}, l, r Type) (Type, error) {
return nil, errors.New("reached fallback")
})

merger.Add(func(l, r *PrimitiveType) (Type, error) {
return StringType, nil
})

result, err := merger.Merge(IntType, BoolType)
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(result).To(Equal(StringType))
}

func TestCanMergeDifferentTypes(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(func(_ interface{}, l, r Type) (Type, error) {
return BoolType, nil
})

merger.Add(func(l *ObjectType, r *PrimitiveType) (Type, error) {
return StringType, nil
})

result, err := merger.Merge(NewObjectType(), IntType) // shouldn't hit fallback
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(result).To(Equal(StringType))

result, err = merger.Merge(IntType, NewObjectType()) // should hit fallback
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(result).To(Equal(BoolType))
}

func TestCanMergeWithGenericTypeArgument(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(func(_ interface{}, l, r Type) (Type, error) {
return BoolType, nil
})

merger.Add(func(l *ObjectType, r Type) (Type, error) {
return StringType, nil
})

result, err := merger.Merge(NewObjectType(), IntType) // shouldn't hit fallback
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(result).To(Equal(StringType))

result, err = merger.Merge(IntType, NewObjectType()) // should hit fallback
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(result).To(Equal(BoolType))
}

func TestCanMergeWithUnorderedMerger(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(func(_ interface{}, l, r Type) (Type, error) {
return BoolType, nil
})

merger.AddUnordered(func(l *ObjectType, r Type) (Type, error) {
return StringType, nil
})

result, err := merger.Merge(NewObjectType(), IntType) // shouldn't hit fallback
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(result).To(Equal(StringType))

result, err = merger.Merge(IntType, NewObjectType()) // shouldn't hit fallback either
g.Expect(err).To(Not(HaveOccurred()))
g.Expect(result).To(Equal(StringType))
}

var leftFallback MergerFunc = func(ctx interface{}, left, right Type) (Type, error) { return left, nil }

func TestAddPanicsWhenPassedANonFunction(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(leftFallback)

g.Expect(func() { merger.Add(123) }).To(PanicWith("merger must be a function"))
}

func TestMergerFuncMustTakeTwoOrThreeArguments(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(leftFallback)

msg := "merger must take take arguments of type (left [some Type], right [some Type]) or (ctx X, left [some Type], right [some Type])"

g.Expect(func() { merger.Add(func() (Type, error) { return nil, nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(x Type) (Type, error) { return x, nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(x, _, _, _ Type) (Type, error) { return x, nil }) }).To(PanicWith(msg))
}

func TestMergerFuncMustTakeTypesAssignableToTypeAsArguments(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(leftFallback)

msg := "merger must take take arguments of type (left [some Type], right [some Type]) or (ctx X, left [some Type], right [some Type])"

// left side wrong
g.Expect(func() { merger.Add(func(_ int, _ Type) (Type, error) { return nil, nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(_ int, _ EnumType) (Type, error) { return nil, nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(_ interface{}, x int, _ Type) (Type, error) { return nil, nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(_ interface{}, _ int, _ EnumType) (Type, error) { return nil, nil }) }).To(PanicWith(msg))

// right side wrong
g.Expect(func() { merger.Add(func(_ Type, _ int) (Type, error) { return nil, nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(_ EnumType, _ int) (Type, error) { return nil, nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(_ interface{}, _ Type, _ int) (Type, error) { return nil, nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(_ interface{}, _ EnumType, _ int) (Type, error) { return nil, nil }) }).To(PanicWith(msg))
}

func TestFuncMustTakeTypesAssignableToTypeAsArguments(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(leftFallback)

msg := "merger must return (Type, error)"

g.Expect(func() { merger.Add(func(_, _ Type) Type { return nil }) }).To(PanicWith(msg))
g.Expect(func() { merger.Add(func(_ interface{}, _, _ Type) Type { return nil }) }).To(PanicWith(msg))
}

func TestMergeReturnsNonNilSide(t *testing.T) {
g := NewGomegaWithT(t)

merger := NewTypeMerger(leftFallback)

var ctx interface{} = nil

// left side
result, err := merger.Merge(StringType, nil)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(result).To(Equal(StringType))

result, err = merger.MergeWithContext(ctx, StringType, nil)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(result).To(Equal(StringType))

// right side
result, err = merger.Merge(nil, StringType)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(result).To(Equal(StringType))

result, err = merger.MergeWithContext(ctx, nil, StringType)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(result).To(Equal(StringType))
}
2 changes: 1 addition & 1 deletion hack/generator/pkg/codegen/pipeline_augment_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ func generateStatusTypes(swaggerTypes swaggerTypes) (statusTypes, error) {
return astmodel.MakeTypeName(typeName.PackageReference, typeName.Name()+"_Status")
}

var errs []error
renamer := makeRenamingVisitor(appendStatusToName)

var errs []error
var otherTypes []astmodel.TypeDefinition
for _, typeDef := range swaggerTypes.otherTypes {
renamedDef, err := renamer.VisitDefinition(typeDef, nil)
Expand Down
Loading

0 comments on commit 8f8b86c

Please # to comment.