diff --git a/common/source.go b/common/source.go index 52377d93..090667af 100644 --- a/common/source.go +++ b/common/source.go @@ -39,6 +39,10 @@ type Source interface { // and second line, or EOF if there is only one line of source. LineOffsets() []int32 + // Macro calls returns the macro calls map containing the original + // expression from a macro replacement, keyed by Id. + MacroCalls() map[int64]*exprpb.Expr + // LocationOffset translates a Location to an offset. // Given the line and column of the Location returns the // Location's character offset in the Source, and a bool @@ -65,6 +69,7 @@ type sourceImpl struct { description string lineOffsets []int32 idOffsets map[int64]int32 + macroCalls map[int64]*exprpb.Expr } var _ runes.Buffer = &sourceImpl{} @@ -93,6 +98,7 @@ func NewStringSource(contents string, description string) Source { description: description, lineOffsets: offsets, idOffsets: map[int64]int32{}, + macroCalls: map[int64]*exprpb.Expr{}, } } @@ -103,6 +109,7 @@ func NewInfoSource(info *exprpb.SourceInfo) Source { description: info.GetLocation(), lineOffsets: info.GetLineOffsets(), idOffsets: info.GetPositions(), + macroCalls: info.GetMacroCalls(), } } @@ -121,6 +128,11 @@ func (s *sourceImpl) LineOffsets() []int32 { return s.lineOffsets } +// MacroCalls implements the Source interface method. +func (s *sourceImpl) MacroCalls() map[int64]*exprpb.Expr { + return s.macroCalls +} + // LocationOffset implements the Source interface method. func (s *sourceImpl) LocationOffset(location Location) (int32, bool) { if lineOffset, found := s.findLineOffset(location.Line()); found { diff --git a/parser/helper.go b/parser/helper.go index 98656eaf..7b10ff37 100644 --- a/parser/helper.go +++ b/parser/helper.go @@ -41,7 +41,8 @@ func (p *parserHelper) getSourceInfo() *exprpb.SourceInfo { return &exprpb.SourceInfo{ Location: p.source.Description(), Positions: p.positions, - LineOffsets: p.source.LineOffsets()} + LineOffsets: p.source.LineOffsets(), + MacroCalls: p.source.MacroCalls()} } func (p *parserHelper) newLiteral(ctx interface{}, value *exprpb.Constant) *exprpb.Expr { @@ -207,6 +208,62 @@ func (p *parserHelper) getLocation(id int64) common.Location { return location } +// buildMacroCallArg iterates the expression and returns a new expression +// where all macros have been replaced by their IDs in MacroCalls +func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr { + resultExpr := &exprpb.Expr{Id: expr.GetId()} + if _, found := p.source.MacroCalls()[expr.GetId()]; found { + return resultExpr + } + + switch expr.ExprKind.(type) { + case *exprpb.Expr_CallExpr: + resultExpr.ExprKind = &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: expr.GetCallExpr().GetFunction(), + }, + } + resultExpr.GetCallExpr().Args = make([]*exprpb.Expr, len(expr.GetCallExpr().GetArgs())) + // Iterate the AST from `expr` recursively looking for macros. Because we are at most + // starting from the top level macro, this recursion is bounded by the size of the AST. This + // means that the depth check on the AST during parsing will catch recursion overflows + // before we get to here. + for index, arg := range expr.GetCallExpr().GetArgs() { + resultExpr.GetCallExpr().GetArgs()[index] = p.buildMacroCallArg(arg) + } + return resultExpr + } + + return expr +} + +// addMacroCall adds the macro the the MacroCalls map in source info. If a macro has args/subargs/target +// that are macros, their ID will be stored instead for later self-lookups. +func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) { + expr := &exprpb.Expr{ + Id: exprID, + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: function, + }, + }, + } + + if target != nil { + if _, found := p.source.MacroCalls()[target.GetId()]; found { + expr.GetCallExpr().Target = &exprpb.Expr{Id: target.GetId()} + } else { + expr.GetCallExpr().Target = target + } + } + + expr.GetCallExpr().Args = make([]*exprpb.Expr, len(args)) + for index, arg := range args { + expr.GetCallExpr().GetArgs()[index] = p.buildMacroCallArg(arg) + } + p.source.MacroCalls()[exprID] = expr +} + // balancer performs tree balancing on operators whose arguments are of equal precedence. // // The purpose of the balancer is to ensure a compact serialization format for the logical &&, || diff --git a/parser/macro.go b/parser/macro.go index de3fc438..baeddd94 100644 --- a/parser/macro.go +++ b/parser/macro.go @@ -285,7 +285,7 @@ func makeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*ex func makeQuantifier(kind quantifierKind, eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - location := eh.OffsetLocation(args[0].Id) + location := eh.OffsetLocation(args[0].GetId()) return nil, &common.Error{ Message: "argument must be a simple name", Location: location} @@ -373,14 +373,14 @@ func makeFilter(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprp func extractIdent(e *exprpb.Expr) (string, bool) { switch e.ExprKind.(type) { case *exprpb.Expr_IdentExpr: - return e.GetIdentExpr().Name, true + return e.GetIdentExpr().GetName(), true } return "", false } func makeHas(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { if s, ok := args[0].ExprKind.(*exprpb.Expr_SelectExpr); ok { - return eh.PresenceTest(s.SelectExpr.Operand, s.SelectExpr.Field), nil + return eh.PresenceTest(s.SelectExpr.GetOperand(), s.SelectExpr.GetField()), nil } return nil, &common.Error{Message: "invalid argument to has() macro"} } diff --git a/parser/options.go b/parser/options.go index 6ee50bd3..b50686a9 100644 --- a/parser/options.go +++ b/parser/options.go @@ -22,6 +22,7 @@ type options struct { errorRecoveryLimit int expressionSizeCodePointLimit int macros map[string]Macro + populateMacroCalls bool } // Option configures the behavior of the parser. @@ -92,3 +93,12 @@ func Macros(macros ...Macro) Option { return nil } } + +// PopulateMacroCalls ensures that the original call signatures replaced by expanded macros +// are preserved in the `SourceInfo` of parse result. +func PopulateMacroCalls(populateMacroCalls bool) Option { + return func(opts *options) error { + opts.populateMacroCalls = populateMacroCalls + return nil + } +} diff --git a/parser/parser.go b/parser/parser.go index 617f884a..11dcc0eb 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -66,6 +66,7 @@ func NewParser(opts ...Option) (*Parser, error) { if p.expressionSizeCodePointLimit == -1 { p.expressionSizeCodePointLimit = int((^uint(0)) >> 1) } + // Bool is false by default, so populateMacroCalls will be false by default return p, nil } @@ -90,6 +91,7 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors maxRecursionDepth: p.maxRecursionDepth, errorRecoveryLimit: p.errorRecoveryLimit, errorRecoveryLookaheadTokenLimit: p.errorRecoveryTokenLookaheadLimit, + populateMacroCalls: p.populateMacroCalls, } buf, ok := source.(runes.Buffer) if !ok { @@ -278,6 +280,7 @@ type parser struct { maxRecursionDepth int errorRecoveryLimit int errorRecoveryLookaheadTokenLimit int + populateMacroCalls bool } var ( @@ -804,7 +807,7 @@ func (p *parser) extractQualifiedName(e *exprpb.Expr) (string, bool) { } switch e.ExprKind.(type) { case *exprpb.Expr_IdentExpr: - return e.GetIdentExpr().Name, true + return e.GetIdentExpr().GetName(), true case *exprpb.Expr_SelectExpr: s := e.GetSelectExpr() if prefix, found := p.extractQualifiedName(s.Operand); found { @@ -812,7 +815,7 @@ func (p *parser) extractQualifiedName(e *exprpb.Expr) (string, bool) { } } // TODO: Add a method to Source to get location from character offset. - location := p.helper.getLocation(e.Id) + location := p.helper.getLocation(e.GetId()) p.reportError(location, "expected a qualified name") return "", false } @@ -833,7 +836,7 @@ func (p *parser) reportError(ctx interface{}, format string, args ...interface{} location = ctx.(common.Location) case antlr.Token, antlr.ParserRuleContext: err := p.helper.newExpr(ctx) - location = p.helper.getLocation(err.Id) + location = p.helper.getLocation(err.GetId()) } err := p.helper.newExpr(ctx) // Provide arguments to the report error. @@ -893,5 +896,8 @@ func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr, } return p.reportError(p.helper.getLocation(exprID), err.Message), true } + if p.populateMacroCalls { + p.helper.addMacroCall(expr.GetId(), function, target, args...) + } return expr, true } diff --git a/parser/parser_test.go b/parser/parser_test.go index b9d7af96..aa376e73 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -17,6 +17,8 @@ package parser import ( "fmt" "reflect" + "sort" + "strings" "testing" "github.com/google/cel-go/common" @@ -432,6 +434,9 @@ var testCases = []testInfo{ I: `has(m.f)`, P: `m^#2:*expr.Expr_IdentExpr#.f~test-only~^#4:*expr.Expr_SelectExpr#`, L: `m^#2[1,4]#.f~test-only~^#4[1,3]#`, + M: `has( + m^#2:*expr.Expr_IdentExpr#.f^#3:*expr.Expr_SelectExpr# + )^#4:has#`, }, { I: `m.exists(v, f)`, @@ -457,6 +462,10 @@ var testCases = []testInfo{ )^#10:*expr.Expr_CallExpr#, // Result __result__^#11:*expr.Expr_IdentExpr#)^#12:*expr.Expr_ComprehensionExpr#`, + M: `m^#1:*expr.Expr_IdentExpr#.exists( + v^#3:*expr.Expr_IdentExpr#, + f^#4:*expr.Expr_IdentExpr# + )^#12:exists#`, }, { I: `m.all(v, f)`, @@ -480,6 +489,10 @@ var testCases = []testInfo{ )^#9:*expr.Expr_CallExpr#, // Result __result__^#10:*expr.Expr_IdentExpr#)^#11:*expr.Expr_ComprehensionExpr#`, + M: `m^#1:*expr.Expr_IdentExpr#.all( + v^#3:*expr.Expr_IdentExpr#, + f^#4:*expr.Expr_IdentExpr# + )^#11:all#`, }, { I: `m.exists_one(v, f)`, @@ -508,6 +521,10 @@ var testCases = []testInfo{ __result__^#12:*expr.Expr_IdentExpr#, 1^#6:*expr.Constant_Int64Value# )^#13:*expr.Expr_CallExpr#)^#14:*expr.Expr_ComprehensionExpr#`, + M: `m^#1:*expr.Expr_IdentExpr#.exists_one( + v^#3:*expr.Expr_IdentExpr#, + f^#4:*expr.Expr_IdentExpr# + )^#14:exists_one#`, }, { I: `m.map(v, f)`, @@ -531,6 +548,10 @@ var testCases = []testInfo{ )^#9:*expr.Expr_CallExpr#, // Result __result__^#5:*expr.Expr_IdentExpr#)^#10:*expr.Expr_ComprehensionExpr#`, + M: `m^#1:*expr.Expr_IdentExpr#.map( + v^#3:*expr.Expr_IdentExpr#, + f^#4:*expr.Expr_IdentExpr# + )^#10:map#`, }, { @@ -559,6 +580,11 @@ var testCases = []testInfo{ )^#11:*expr.Expr_CallExpr#, // Result __result__^#6:*expr.Expr_IdentExpr#)^#12:*expr.Expr_ComprehensionExpr#`, + M: `m^#1:*expr.Expr_IdentExpr#.map( + v^#3:*expr.Expr_IdentExpr#, + p^#4:*expr.Expr_IdentExpr#, + f^#5:*expr.Expr_IdentExpr# + )^#12:map#`, }, { @@ -587,6 +613,10 @@ var testCases = []testInfo{ )^#10:*expr.Expr_CallExpr#, // Result __result__^#5:*expr.Expr_IdentExpr#)^#11:*expr.Expr_ComprehensionExpr#`, + M: `m^#1:*expr.Expr_IdentExpr#.filter( + v^#3:*expr.Expr_IdentExpr#, + p^#4:*expr.Expr_IdentExpr# + )^#11:filter#`, }, // Tests from C++ parser @@ -1243,6 +1273,198 @@ var testCases = []testInfo{ | 0"""\""\"""\""\"""\""\"""\""\"""\"\"""\""\"""\""\"""\""\"""\"!\"""\""\"""\""\" | ..........^`, }, + // Macro Calls Tests + { + I: `x.filter(y, y.filter(z, z > 0))`, + P: `__comprehension__( + // Variable + y, + // Target + x^#1:*expr.Expr_IdentExpr#, + // Accumulator + __result__, + // Init + []^#18:*expr.Expr_ListExpr#, + // LoopCondition + true^#19:*expr.Constant_BoolValue#, + // LoopStep + _?_:_( + __comprehension__( + // Variable + z, + // Target + y^#4:*expr.Expr_IdentExpr#, + // Accumulator + __result__, + // Init + []^#11:*expr.Expr_ListExpr#, + // LoopCondition + true^#12:*expr.Constant_BoolValue#, + // LoopStep + _?_:_( + _>_( + z^#7:*expr.Expr_IdentExpr#, + 0^#9:*expr.Constant_Int64Value# + )^#8:*expr.Expr_CallExpr#, + _+_( + __result__^#10:*expr.Expr_IdentExpr#, + [ + z^#6:*expr.Expr_IdentExpr# + ]^#13:*expr.Expr_ListExpr# + )^#14:*expr.Expr_CallExpr#, + __result__^#10:*expr.Expr_IdentExpr# + )^#15:*expr.Expr_CallExpr#, + // Result + __result__^#10:*expr.Expr_IdentExpr#)^#16:*expr.Expr_ComprehensionExpr#, + _+_( + __result__^#17:*expr.Expr_IdentExpr#, + [ + y^#3:*expr.Expr_IdentExpr# + ]^#20:*expr.Expr_ListExpr# + )^#21:*expr.Expr_CallExpr#, + __result__^#17:*expr.Expr_IdentExpr# + )^#22:*expr.Expr_CallExpr#, + // Result + __result__^#17:*expr.Expr_IdentExpr#)^#23:*expr.Expr_ComprehensionExpr#`, + M: `x^#1:*expr.Expr_IdentExpr#.filter( + y^#3:*expr.Expr_IdentExpr#, + ^#16:filter# + )^#23:filter#, + y^#4:*expr.Expr_IdentExpr#.filter( + z^#6:*expr.Expr_IdentExpr#, + _>_( + z^#7:*expr.Expr_IdentExpr#, + 0^#9:*expr.Constant_Int64Value# + )^#8:*expr.Expr_CallExpr# + )^#16:filter#`, + }, + { + I: `has(a.b).filter(c, c)`, + P: `__comprehension__( + // Variable + c, + // Target + a^#2:*expr.Expr_IdentExpr#.b~test-only~^#4:*expr.Expr_SelectExpr#, + // Accumulator + __result__, + // Init + []^#9:*expr.Expr_ListExpr#, + // LoopCondition + true^#10:*expr.Constant_BoolValue#, + // LoopStep + _?_:_( + c^#7:*expr.Expr_IdentExpr#, + _+_( + __result__^#8:*expr.Expr_IdentExpr#, + [ + c^#6:*expr.Expr_IdentExpr# + ]^#11:*expr.Expr_ListExpr# + )^#12:*expr.Expr_CallExpr#, + __result__^#8:*expr.Expr_IdentExpr# + )^#13:*expr.Expr_CallExpr#, + // Result + __result__^#8:*expr.Expr_IdentExpr#)^#14:*expr.Expr_ComprehensionExpr#`, + M: `^#4:has#.filter( + c^#6:*expr.Expr_IdentExpr#, + c^#7:*expr.Expr_IdentExpr# + )^#14:filter#, + has( + a^#2:*expr.Expr_IdentExpr#.b^#3:*expr.Expr_SelectExpr# + )^#4:has#`, + }, + { + I: `x.filter(y, y.exists(z, has(z.a)) && y.exists(z, has(z.b)))`, + P: `__comprehension__( + // Variable + y, + // Target + x^#1:*expr.Expr_IdentExpr#, + // Accumulator + __result__, + // Init + []^#36:*expr.Expr_ListExpr#, + // LoopCondition + true^#37:*expr.Constant_BoolValue#, + // LoopStep + _?_:_( + _&&_( + __comprehension__( + // Variable + z, + // Target + y^#4:*expr.Expr_IdentExpr#, + // Accumulator + __result__, + // Init + false^#11:*expr.Constant_BoolValue#, + // LoopCondition + @not_strictly_false( + !_( + __result__^#12:*expr.Expr_IdentExpr# + )^#13:*expr.Expr_CallExpr# + )^#14:*expr.Expr_CallExpr#, + // LoopStep + _||_( + __result__^#15:*expr.Expr_IdentExpr#, + z^#8:*expr.Expr_IdentExpr#.a~test-only~^#10:*expr.Expr_SelectExpr# + )^#16:*expr.Expr_CallExpr#, + // Result + __result__^#17:*expr.Expr_IdentExpr#)^#18:*expr.Expr_ComprehensionExpr#, + __comprehension__( + // Variable + z, + // Target + y^#19:*expr.Expr_IdentExpr#, + // Accumulator + __result__, + // Init + false^#26:*expr.Constant_BoolValue#, + // LoopCondition + @not_strictly_false( + !_( + __result__^#27:*expr.Expr_IdentExpr# + )^#28:*expr.Expr_CallExpr# + )^#29:*expr.Expr_CallExpr#, + // LoopStep + _||_( + __result__^#30:*expr.Expr_IdentExpr#, + z^#23:*expr.Expr_IdentExpr#.b~test-only~^#25:*expr.Expr_SelectExpr# + )^#31:*expr.Expr_CallExpr#, + // Result + __result__^#32:*expr.Expr_IdentExpr#)^#33:*expr.Expr_ComprehensionExpr# + )^#34:*expr.Expr_CallExpr#, + _+_( + __result__^#35:*expr.Expr_IdentExpr#, + [ + y^#3:*expr.Expr_IdentExpr# + ]^#38:*expr.Expr_ListExpr# + )^#39:*expr.Expr_CallExpr#, + __result__^#35:*expr.Expr_IdentExpr# + )^#40:*expr.Expr_CallExpr#, + // Result + __result__^#35:*expr.Expr_IdentExpr#)^#41:*expr.Expr_ComprehensionExpr#`, + M: `x^#1:*expr.Expr_IdentExpr#.filter( + y^#3:*expr.Expr_IdentExpr#, + _&&_( + ^#18:exists#, + ^#33:exists# + )^#34:*expr.Expr_CallExpr# + )^#41:filter#, + y^#19:*expr.Expr_IdentExpr#.exists( + z^#21:*expr.Expr_IdentExpr#, + ^#25:has# + )^#33:exists#, + has( + z^#23:*expr.Expr_IdentExpr#.b^#24:*expr.Expr_SelectExpr# + )^#25:has#, + y^#4:*expr.Expr_IdentExpr#.exists( + z^#6:*expr.Expr_IdentExpr#, + ^#10:has# + )^#18:exists#, + has( + z^#8:*expr.Expr_IdentExpr#.a^#9:*expr.Expr_SelectExpr# + )^#10:has#`, + }, } type testInfo struct { @@ -1257,6 +1479,9 @@ type testInfo struct { // L contains the expected source adorned debug output of the expression tree. L string + + // M contains the expected adorned debug output of the macro calls map + M string } type metadata interface { @@ -1264,21 +1489,27 @@ type metadata interface { } type kindAndIDAdorner struct { + sourceInfo *exprpb.SourceInfo } func (k *kindAndIDAdorner) GetMetadata(elem interface{}) string { switch elem.(type) { case *exprpb.Expr: e := elem.(*exprpb.Expr) + if k.sourceInfo != nil { + if val, found := k.sourceInfo.MacroCalls[e.GetId()]; found { + return fmt.Sprintf("^#%d:%s#", e.Id, val.GetCallExpr().GetFunction()) + } + } var valType interface{} = e.ExprKind switch valType.(type) { case *exprpb.Expr_ConstExpr: - valType = e.GetConstExpr().ConstantKind + valType = e.GetConstExpr().GetConstantKind() } - return fmt.Sprintf("^#%d:%s#", e.Id, reflect.TypeOf(valType)) + return fmt.Sprintf("^#%d:%s#", e.GetId(), reflect.TypeOf(valType)) case *exprpb.Expr_CreateStruct_Entry: entry := elem.(*exprpb.Expr_CreateStruct_Entry) - return fmt.Sprintf("^#%d:%s#", entry.Id, "*expr.Expr_CreateStruct_Entry") + return fmt.Sprintf("^#%d:%s#", entry.GetId(), "*expr.Expr_CreateStruct_Entry") } return "" } @@ -1290,9 +1521,9 @@ type locationAdorner struct { var _ metadata = &locationAdorner{} func (l *locationAdorner) GetLocation(exprID int64) (common.Location, bool) { - if pos, found := l.sourceInfo.Positions[exprID]; found { + if pos, found := l.sourceInfo.GetPositions()[exprID]; found { var line = 1 - for _, lineOffset := range l.sourceInfo.LineOffsets { + for _, lineOffset := range l.sourceInfo.GetLineOffsets() { if lineOffset > pos { break } else { @@ -1301,7 +1532,7 @@ func (l *locationAdorner) GetLocation(exprID int64) (common.Location, bool) { } var column = pos if line > 1 { - column = pos - l.sourceInfo.LineOffsets[line-2] + column = pos - l.sourceInfo.GetLineOffsets()[line-2] } return common.NewLocation(line, int(column)), true } @@ -1312,20 +1543,39 @@ func (l *locationAdorner) GetMetadata(elem interface{}) string { var elemID int64 switch elem.(type) { case *exprpb.Expr: - elemID = elem.(*exprpb.Expr).Id + elemID = elem.(*exprpb.Expr).GetId() case *exprpb.Expr_CreateStruct_Entry: - elemID = elem.(*exprpb.Expr_CreateStruct_Entry).Id + elemID = elem.(*exprpb.Expr_CreateStruct_Entry).GetId() } location, _ := l.GetLocation(elemID) return fmt.Sprintf("^#%d[%d,%d]#", elemID, location.Line(), location.Column()) } +func convertMacroCallsToString(source *exprpb.SourceInfo) string { + keys := make([]int64, len(source.GetMacroCalls())) + adornedStrings := make([]string, len(source.GetMacroCalls())) + i := 0 + for k := range source.GetMacroCalls() { + keys[i] = k + i++ + } + // Sort the keys in descending order to create a stable ordering for tests and improve readability. + sort.Slice(keys, func(i, j int) bool { return keys[i] > keys[j] }) + i = 0 + for _, key := range keys { + adornedStrings[i] = debug.ToAdornedDebugString(source.GetMacroCalls()[int64(key)], &kindAndIDAdorner{sourceInfo: source}) + i++ + } + return strings.Join(adornedStrings, ",\n") +} + func TestParse(t *testing.T) { p, err := NewParser( Macros(AllMacros...), MaxRecursionDepth(32), ErrorRecoveryLimit(4), ErrorRecoveryLookaheadTokenLimit(4), + PopulateMacroCalls(true), ) if err != nil { t.Fatal(err) @@ -1353,16 +1603,23 @@ func TestParse(t *testing.T) { } else if tc.E != "" { tt.Fatalf("Expected error not thrown: '%s'", tc.E) } - + failureDisplayMethod := fmt.Sprintf("Parse(\"%s\")", tc.I) actualWithKind := debug.ToAdornedDebugString(expression.Expr, &kindAndIDAdorner{}) if !test.Compare(actualWithKind, tc.P) { - tt.Fatal(test.DiffMessage("structure", actualWithKind, tc.P)) + tt.Fatal(test.DiffMessage(fmt.Sprintf("Structure - %s", failureDisplayMethod), actualWithKind, tc.P)) } if tc.L != "" { - actualWithLocation := debug.ToAdornedDebugString(expression.Expr, &locationAdorner{expression.SourceInfo}) + actualWithLocation := debug.ToAdornedDebugString(expression.Expr, &locationAdorner{expression.GetSourceInfo()}) if !test.Compare(actualWithLocation, tc.L) { - tt.Fatal(test.DiffMessage("location", actualWithLocation, tc.L)) + tt.Fatal(test.DiffMessage(fmt.Sprintf("Location - %s", failureDisplayMethod), actualWithLocation, tc.L)) + } + } + + if tc.M != "" { + actualAdornedMacroCalls := convertMacroCallsToString(expression.GetSourceInfo()) + if !test.Compare(actualAdornedMacroCalls, tc.M) { + tt.Fatal(test.DiffMessage(fmt.Sprintf("Macro Calls - %s", failureDisplayMethod), actualAdornedMacroCalls, tc.M)) } } }) diff --git a/test/compare.go b/test/compare.go index b90f3278..0b5044ec 100644 --- a/test/compare.go +++ b/test/compare.go @@ -35,11 +35,5 @@ func Compare(a string, e string) bool { // DiffMessage creates a diff dump message for test failures. func DiffMessage(context string, actual interface{}, expected interface{}) string { - result := fmt.Sprintf("FAILURE(%s)\n", context) - result += "\n===== ACTUAL =====\n" - result += strings.TrimSpace(fmt.Sprintf("%v", actual)) - result += "\n==== EXPECTED ====\n" - result += strings.TrimSpace(fmt.Sprintf("%v", expected)) - result += "\n==================\n" - return result + return fmt.Sprintf("%s: \ngot %q, \nwanted %q", context, actual, expected) }