diff --git a/README.md b/README.md index 6676fe9..92c4b64 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ Duration Check A Go linter to detect cases where two `time.Duration` values are being multiplied in possibly erroneous ways. -For example, consider the following snippet: +For example, consider the following (highly contrived) function: ```go func waitFor(someDuration time.Duration) { @@ -12,11 +12,13 @@ func waitFor(someDuration time.Duration) { } ``` -Although the above code would compile without any errors, the behaviour is most likely to be incorrect. A caller would -reasonably expect `waitFor(5 * time.Seconds)` to wait for ~5 seconds but they would end up waiting for ~1,388,889 hours. +Although the above code would compile without any errors, its runtime behaviour would almost certainly be incorrect. +A caller would reasonably expect `waitFor(5 * time.Seconds)` to wait for ~5 seconds but they would actually end up +waiting for ~1,388,889 hours. -A majority of these problems would be spotted almost immediately but some could still slip through unnoticed. Hopefully -this linter will help catch those rare cases before they cause a production issue. +The above example is just for illustration purposes only. The problem is glaringly obvious in such a simple function +and even the greenest Gopher would discover the issue immediately. However, imagine a much more complicated function +with many more lines and it is not inconceivable that such logic errors could go unnoticed. See the [test cases](testdata/src/a/a.go) for more examples of the types of errors detected by the linter. diff --git a/durationcheck.go b/durationcheck.go index da1db5f..33899c4 100644 --- a/durationcheck.go +++ b/durationcheck.go @@ -80,15 +80,19 @@ func isDuration(x types.Type) bool { // isUnacceptableExpr returns true if the argument is not an acceptable time.Duration expression func isUnacceptableExpr(pass *analysis.Pass, expr ast.Expr) bool { switch e := expr.(type) { - case *ast.BasicLit: // constants are acceptable + case *ast.BasicLit: return false - case *ast.CallExpr: // explicit casting of constants such as `time.Duration(10)` is acceptable + case *ast.CallExpr: return !isAcceptableCast(pass, e) + case *ast.BinaryExpr: + return !isAcceptableNestedExpr(pass, e) + case *ast.UnaryExpr: + return !isAcceptableNestedExpr(pass, e) } return true } -// isAcceptableCast returns true if the argument is a constant expression cast to time.Duration +// isAcceptableCast returns true if the argument is an acceptable expression cast to time.Duration func isAcceptableCast(pass *analysis.Pass, e *ast.CallExpr) bool { // check that there's a single argument if len(e.Args) != 1 { @@ -96,7 +100,7 @@ func isAcceptableCast(pass *analysis.Pass, e *ast.CallExpr) bool { } // check that the argument is acceptable - if !isAcceptableCastArg(pass, e.Args[0]) { + if !isAcceptableNestedExpr(pass, e.Args[0]) { return false } @@ -106,6 +110,10 @@ func isAcceptableCast(pass *analysis.Pass, e *ast.CallExpr) bool { return false } + return isDurationCast(selector) +} + +func isDurationCast(selector *ast.SelectorExpr) bool { pkg, ok := selector.X.(*ast.Ident) if !ok { return false @@ -118,16 +126,22 @@ func isAcceptableCast(pass *analysis.Pass, e *ast.CallExpr) bool { return selector.Sel.Name == "Duration" } -func isAcceptableCastArg(pass *analysis.Pass, n ast.Expr) bool { +func isAcceptableNestedExpr(pass *analysis.Pass, n ast.Expr) bool { switch e := n.(type) { case *ast.BasicLit: return true case *ast.BinaryExpr: - return isAcceptableCastArg(pass, e.X) && isAcceptableCastArg(pass, e.Y) - default: - argType := pass.TypesInfo.TypeOf(n) - return !isDuration(argType) + return isAcceptableNestedExpr(pass, e.X) && isAcceptableNestedExpr(pass, e.Y) + case *ast.UnaryExpr: + return isAcceptableNestedExpr(pass, e.X) + case *ast.Ident: + t := pass.TypesInfo.TypeOf(e) + return !isDuration(t) + case *ast.CallExpr: + t := pass.TypesInfo.TypeOf(e) + return !isDuration(t) } + return false } func formatNode(node ast.Node) string { @@ -140,8 +154,8 @@ func formatNode(node ast.Node) string { return buf.String() } -func printAST(node ast.Node) { - fmt.Printf(">>> %s\n", formatNode(node)) +func printAST(msg string, node ast.Node) { + fmt.Printf(">>> %s:\n%s\n\n\n", msg, formatNode(node)) ast.Fprint(os.Stdout, nil, node, nil) fmt.Println("--------------") } diff --git a/testdata/src/a/a.go b/testdata/src/a/a.go index be7a537..5016b69 100644 --- a/testdata/src/a/a.go +++ b/testdata/src/a/a.go @@ -6,8 +6,7 @@ import ( const timeout = 10 * time.Second -func multiplyTwoDurations() { - x := 30 * time.Second +func validCases() { y := 10 _ = time.Second * 30 @@ -20,6 +19,14 @@ func multiplyTwoDurations() { _ = time.Second * time.Duration(10+20*5) + _ = 2 * 24 * time.Hour + + _ = time.Hour * 2 * 24 + + _ = -1 * time.Hour + + _ = time.Hour * -1 + _ = time.Duration(y) * time.Second _ = time.Second * time.Duration(y) @@ -29,13 +36,17 @@ func multiplyTwoDurations() { _ = time.Millisecond * time.Duration(someDurationMillis()) _ = timeout / time.Millisecond +} - _ = timeout * time.Millisecond // want `Multiplication of durations` +func invalidCases() { + x := 30 * time.Second _ = x * time.Second // want `Multiplication of durations` _ = time.Second * x // want `Multiplication of durations` + _ = timeout * time.Millisecond // want `Multiplication of durations` + _ = someDuration() * time.Second // want `Multiplication of durations` _ = time.Millisecond * someDuration() // want `Multiplication of durations`