-
Notifications
You must be signed in to change notification settings - Fork 209
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This is called `TypeMerger` but it is really multiple dispatch for functions with type `([some Type], [some Type]) → Type`. Ideally this would be a bit more generalized to _n_ arguments but Go doesn’t support that kind of thing.
- Loading branch information
Showing
4 changed files
with
397 additions
and
171 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
/* | ||
* Copyright (c) Microsoft Corporation. | ||
* Licensed under the MIT license. | ||
*/ | ||
|
||
package astmodel | ||
|
||
import ( | ||
"reflect" | ||
|
||
"github.com/pkg/errors" | ||
) | ||
|
||
// TypeMerger is like a visitor for 2 types. | ||
// | ||
// Conceptually it takes (via Add) a list of functions of the form: | ||
// | ||
// func ([ctx interface{},] left {some Type}, right {some Type}) (Type, error) | ||
// | ||
// where `left` and `right` can be concrete types that implement the `Type` interface. | ||
// | ||
// When `TypeMerger.Merge(Type, Type)` is invoked, it will iterate through the | ||
// provided functions and invoke the first one that matches the concrete types | ||
// passed. If none match then the fallback provided to `NewTypeMerger` will be invoked. | ||
// | ||
// The `ctx` argument can optionally be used to “smuggle” additional data down the call-chain. | ||
type TypeMerger struct { | ||
mergers []mergerRegistration | ||
fallback MergerFunc | ||
} | ||
|
||
type mergerRegistration struct { | ||
left reflect.Type | ||
right reflect.Type | ||
merge MergerFunc | ||
} | ||
|
||
type MergerFunc func(ctx interface{}, left, right Type) (Type, error) | ||
|
||
func NewTypeMerger(fallback MergerFunc) TypeMerger { | ||
return TypeMerger{fallback: fallback} | ||
} | ||
|
||
var typeInterface reflect.Type = reflect.TypeOf((*Type)(nil)).Elem() // yuck | ||
var errorInterface reflect.Type = reflect.TypeOf((*error)(nil)).Elem() | ||
var mergerFuncType reflect.Type = reflect.TypeOf((*MergerFunc)(nil)).Elem() | ||
|
||
type validatedMerger struct { | ||
merger reflect.Value | ||
leftArgType, rightArgType reflect.Type | ||
needsCtx bool | ||
} | ||
|
||
func validateMerger(merger interface{}) validatedMerger { | ||
it := reflect.ValueOf(merger) | ||
if it.Kind() != reflect.Func { | ||
panic("merger must be a function") | ||
} | ||
|
||
mergerType := it.Type() | ||
|
||
badArgumentsMsg := "merger must take take arguments of type (left [some Type], right [some Type]) or (ctx X, left [some Type], right [some Type])" | ||
if mergerType.NumIn() < 2 || mergerType.NumIn() > 3 { | ||
panic(badArgumentsMsg) | ||
} | ||
|
||
argOffset := mergerType.NumIn() - 2 | ||
|
||
needsCtx := argOffset != 0 | ||
leftArg := mergerType.In(argOffset + 0) | ||
rightArg := mergerType.In(argOffset + 1) | ||
|
||
if !leftArg.AssignableTo(typeInterface) || !rightArg.AssignableTo(typeInterface) { | ||
panic(badArgumentsMsg) | ||
} | ||
|
||
if mergerType.NumOut() != 2 || | ||
mergerType.Out(0) != typeInterface || | ||
mergerType.Out(1) != errorInterface { | ||
panic("merger must return (Type, error)") | ||
} | ||
|
||
return validatedMerger{ | ||
merger: it, | ||
leftArgType: leftArg, | ||
rightArgType: rightArg, | ||
needsCtx: needsCtx, | ||
} | ||
} | ||
|
||
func buildMergerRegistration(v validatedMerger, flip bool) mergerRegistration { | ||
leftArgType := v.leftArgType | ||
rightArgType := v.rightArgType | ||
if flip { | ||
leftArgType, rightArgType = rightArgType, leftArgType | ||
} | ||
|
||
return mergerRegistration{ | ||
left: leftArgType, | ||
right: rightArgType, | ||
merge: reflect.MakeFunc(mergerFuncType, func(args []reflect.Value) []reflect.Value { | ||
// we dereference the Type here to the underlying value so that | ||
// the merger can take either Type or a specific type. | ||
// if it takes Type then the compiler/runtime will convert the value back to a Type | ||
ctxValue := args[0].Elem() | ||
leftValue := args[1].Elem() | ||
rightValue := args[2].Elem() | ||
if flip { | ||
leftValue, rightValue = rightValue, leftValue | ||
} | ||
|
||
if v.needsCtx { | ||
return v.merger.Call([]reflect.Value{ctxValue, leftValue, rightValue}) | ||
} else { | ||
return v.merger.Call([]reflect.Value{leftValue, rightValue}) | ||
} | ||
}).Interface().(MergerFunc), | ||
} | ||
} | ||
|
||
// Add adds a handler function to be invoked if applicable. See the docs on | ||
// TypeMerger above for a full explanation. | ||
func (m *TypeMerger) Add(mergeFunc interface{}) { | ||
v := validateMerger(mergeFunc) | ||
m.mergers = append(m.mergers, buildMergerRegistration(v, false)) | ||
} | ||
|
||
// AddUnordered adds a handler function that doesn’t care what order | ||
// the two type parameters are in. e.g. if it has type `(A, B) -> (Type, error)`, | ||
// it can match either `(A, B)` or `(B, A)`. This is useful when the merger | ||
// is symmetric and handles two different types. | ||
func (m *TypeMerger) AddUnordered(mergeFunc interface{}) { | ||
v := validateMerger(mergeFunc) | ||
m.mergers = append(m.mergers, buildMergerRegistration(v, false), buildMergerRegistration(v, true)) | ||
} | ||
|
||
// Merge merges the two types according to the provided mergers and fallback, with nil context | ||
func (m *TypeMerger) Merge(left, right Type) (Type, error) { | ||
return m.MergeWithContext(nil, left, right) | ||
} | ||
|
||
// MergeWithContext merges the two types according to the provided mergers and fallback, with the provided context | ||
func (m *TypeMerger) MergeWithContext(ctx interface{}, left, right Type) (Type, error) { | ||
if left == nil { | ||
return right, nil | ||
} | ||
|
||
if right == nil { | ||
return left, nil | ||
} | ||
|
||
leftType := reflect.ValueOf(left).Type() | ||
rightType := reflect.ValueOf(right).Type() | ||
|
||
for _, merger := range m.mergers { | ||
leftTypeMatches := merger.left == leftType || merger.left == typeInterface | ||
rightTypeMatches := merger.right == rightType || merger.right == typeInterface | ||
|
||
if leftTypeMatches && rightTypeMatches { | ||
result, err := merger.merge(ctx, left, right) | ||
if (result == nil && err == nil) || errors.Is(err, ContinueMerge) { | ||
// these conditions indicate that the merger was not actually applicable, | ||
// despite having a type that matched | ||
continue | ||
} | ||
|
||
return result, err | ||
} | ||
} | ||
|
||
return m.fallback(ctx, left, right) | ||
} | ||
|
||
var ContinueMerge error = errors.New("special error that indicates that the merger was not applicable") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
/* | ||
* Copyright (c) Microsoft Corporation. | ||
* Licensed under the MIT license. | ||
*/ | ||
|
||
package astmodel | ||
|
||
import ( | ||
"testing" | ||
|
||
. "github.com/onsi/gomega" | ||
"github.com/pkg/errors" | ||
) | ||
|
||
func TestCanMergeSameTypes(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(func(_ interface{}, l, r Type) (Type, error) { | ||
return nil, errors.New("reached fallback") | ||
}) | ||
|
||
merger.Add(func(l, r *PrimitiveType) (Type, error) { | ||
return StringType, nil | ||
}) | ||
|
||
result, err := merger.Merge(IntType, BoolType) | ||
g.Expect(err).To(Not(HaveOccurred())) | ||
g.Expect(result).To(Equal(StringType)) | ||
} | ||
|
||
func TestCanMergeDifferentTypes(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(func(_ interface{}, l, r Type) (Type, error) { | ||
return BoolType, nil | ||
}) | ||
|
||
merger.Add(func(l *ObjectType, r *PrimitiveType) (Type, error) { | ||
return StringType, nil | ||
}) | ||
|
||
result, err := merger.Merge(NewObjectType(), IntType) // shouldn't hit fallback | ||
g.Expect(err).To(Not(HaveOccurred())) | ||
g.Expect(result).To(Equal(StringType)) | ||
|
||
result, err = merger.Merge(IntType, NewObjectType()) // should hit fallback | ||
g.Expect(err).To(Not(HaveOccurred())) | ||
g.Expect(result).To(Equal(BoolType)) | ||
} | ||
|
||
func TestCanMergeWithGenericTypeArgument(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(func(_ interface{}, l, r Type) (Type, error) { | ||
return BoolType, nil | ||
}) | ||
|
||
merger.Add(func(l *ObjectType, r Type) (Type, error) { | ||
return StringType, nil | ||
}) | ||
|
||
result, err := merger.Merge(NewObjectType(), IntType) // shouldn't hit fallback | ||
g.Expect(err).To(Not(HaveOccurred())) | ||
g.Expect(result).To(Equal(StringType)) | ||
|
||
result, err = merger.Merge(IntType, NewObjectType()) // should hit fallback | ||
g.Expect(err).To(Not(HaveOccurred())) | ||
g.Expect(result).To(Equal(BoolType)) | ||
} | ||
|
||
func TestCanMergeWithUnorderedMerger(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(func(_ interface{}, l, r Type) (Type, error) { | ||
return BoolType, nil | ||
}) | ||
|
||
merger.AddUnordered(func(l *ObjectType, r Type) (Type, error) { | ||
return StringType, nil | ||
}) | ||
|
||
result, err := merger.Merge(NewObjectType(), IntType) // shouldn't hit fallback | ||
g.Expect(err).To(Not(HaveOccurred())) | ||
g.Expect(result).To(Equal(StringType)) | ||
|
||
result, err = merger.Merge(IntType, NewObjectType()) // shouldn't hit fallback either | ||
g.Expect(err).To(Not(HaveOccurred())) | ||
g.Expect(result).To(Equal(StringType)) | ||
} | ||
|
||
var leftFallback MergerFunc = func(ctx interface{}, left, right Type) (Type, error) { return left, nil } | ||
|
||
func TestAddPanicsWhenPassedANonFunction(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(leftFallback) | ||
|
||
g.Expect(func() { merger.Add(123) }).To(PanicWith("merger must be a function")) | ||
} | ||
|
||
func TestMergerFuncMustTakeTwoOrThreeArguments(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(leftFallback) | ||
|
||
msg := "merger must take take arguments of type (left [some Type], right [some Type]) or (ctx X, left [some Type], right [some Type])" | ||
|
||
g.Expect(func() { merger.Add(func() (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(x Type) (Type, error) { return x, nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(x, _, _, _ Type) (Type, error) { return x, nil }) }).To(PanicWith(msg)) | ||
} | ||
|
||
func TestMergerFuncMustTakeTypesAssignableToTypeAsArguments(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(leftFallback) | ||
|
||
msg := "merger must take take arguments of type (left [some Type], right [some Type]) or (ctx X, left [some Type], right [some Type])" | ||
|
||
// left side wrong | ||
g.Expect(func() { merger.Add(func(_ int, _ Type) (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(_ int, _ EnumType) (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(_ interface{}, x int, _ Type) (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(_ interface{}, _ int, _ EnumType) (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
|
||
// right side wrong | ||
g.Expect(func() { merger.Add(func(_ Type, _ int) (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(_ EnumType, _ int) (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(_ interface{}, _ Type, _ int) (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(_ interface{}, _ EnumType, _ int) (Type, error) { return nil, nil }) }).To(PanicWith(msg)) | ||
} | ||
|
||
func TestFuncMustTakeTypesAssignableToTypeAsArguments(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(leftFallback) | ||
|
||
msg := "merger must return (Type, error)" | ||
|
||
g.Expect(func() { merger.Add(func(_, _ Type) Type { return nil }) }).To(PanicWith(msg)) | ||
g.Expect(func() { merger.Add(func(_ interface{}, _, _ Type) Type { return nil }) }).To(PanicWith(msg)) | ||
} | ||
|
||
func TestMergeReturnsNonNilSide(t *testing.T) { | ||
g := NewGomegaWithT(t) | ||
|
||
merger := NewTypeMerger(leftFallback) | ||
|
||
var ctx interface{} = nil | ||
|
||
// left side | ||
result, err := merger.Merge(StringType, nil) | ||
g.Expect(err).ToNot(HaveOccurred()) | ||
g.Expect(result).To(Equal(StringType)) | ||
|
||
result, err = merger.MergeWithContext(ctx, StringType, nil) | ||
g.Expect(err).ToNot(HaveOccurred()) | ||
g.Expect(result).To(Equal(StringType)) | ||
|
||
// right side | ||
result, err = merger.Merge(nil, StringType) | ||
g.Expect(err).ToNot(HaveOccurred()) | ||
g.Expect(result).To(Equal(StringType)) | ||
|
||
result, err = merger.MergeWithContext(ctx, nil, StringType) | ||
g.Expect(err).ToNot(HaveOccurred()) | ||
g.Expect(result).To(Equal(StringType)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.