diff --git a/main.go b/main.go index c3ce2ec..ec76073 100644 --- a/main.go +++ b/main.go @@ -127,6 +127,10 @@ func main() { "help", false, "Prints this usage help to the user.") + timeoutPtr := flag.Int( + "timeout", + 5, + "Timeout in seconds for the connection to the server.") flag.Parse() color.NoColor = *noColor if (*connectPtr == "" && *listenPtr == "" && !*versionPtr) || *helpPtr { @@ -146,13 +150,13 @@ func main() { if *connectPtr != "" { address := formatAddress(*connectPtr, tscanner.ServerScan) var err error - if report, err = tscanner.Scan(address, tscanner.ServerScan, true); err != nil { + if report, err = tscanner.ScanWithTimeout(address, tscanner.ServerScan, true, *timeoutPtr); err != nil { panic(err) } } else if *listenPtr != "" { address := formatAddress(*listenPtr, tscanner.ClientScan) var err error - if report, err = tscanner.Scan(address, tscanner.ClientScan, true); err != nil { + if report, err = tscanner.ScanWithTimeout(address, tscanner.ClientScan, true, *timeoutPtr); err != nil { panic(err) } } diff --git a/tscanner/tscanner.go b/tscanner/tscanner.go index 50a9eb0..c0b7eec 100644 --- a/tscanner/tscanner.go +++ b/tscanner/tscanner.go @@ -7,11 +7,13 @@ import ( "encoding/binary" "encoding/json" "fmt" - "golang.org/x/exp/slices" "io" "net" "os" "strings" + "time" + + "golang.org/x/exp/slices" ) // ScanMode describes a scan mode for the scanner. @@ -60,10 +62,17 @@ func (report *Report) MarshalJSON() ([]byte, error) { // Scan performs a vulnerability scan to check whether the remote peer is likely to be vulnerable against prefix truncation. func Scan(address string, scanMode ScanMode, verbose bool) (*Report, error) { + return ScanWithTimeout(address, scanMode, verbose, 0) +} + +// ScanWithTimeout performs a vulnerability scan with configurable timeout to check whether the remote peer +// is likely to be vulnerable against prefix truncation. +func ScanWithTimeout(address string, scanMode ScanMode, verbose bool, timeout int) (*Report, error) { var conn net.Conn if scanMode == ServerScan { var err error - if conn, err = net.Dial("tcp", address); err != nil { + dialer := net.Dialer{Timeout: time.Duration(timeout) * time.Second} + if conn, err = dialer.Dial("tcp", address); err != nil { return nil, err } } else if scanMode == ClientScan {