diff --git a/lib/myipio/myipio.go b/lib/myipio/myipio.go index beb85c8..9f5b472 100644 --- a/lib/myipio/myipio.go +++ b/lib/myipio/myipio.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "net/url" "strings" ) @@ -34,61 +33,61 @@ var ( ErrNoIPv4 = errors.New("ip address is not ipv4") ErrInvalid = errors.New("ip address is invalid") ErrStatus = errors.New("unusual my-ip.io server response") -) -const ( - domain = "api4.my-ip.io" - link = "https://api4.my-ip.io/ip.json" + link = "https://api4.my-ip.io/ip.json" ) +const domain = "api4.my-ip.io" + // IPv4 returns the Internet facing IP address of the free my-ip.io service. func IPv4(ctx context.Context, cancel context.CancelFunc) (string, error) { - s, err := request(ctx, cancel, link) + r, err := request(ctx, cancel, link) + if err == nil && ctx.Err() == context.Canceled { + return "", nil + } if err != nil { - if _, ok := err.(*url.Error); ok { - if strings.Contains(err.Error(), "context deadline exceeded") { - fmt.Printf("\n%s: timeout\n", domain) - return "", nil - } - fmt.Printf("\n%s: %s\n", domain, err) + switch errors.Unwrap(err) { + case context.DeadlineExceeded: + fmt.Printf("\n%s: timeout\n", domain) return "", nil + default: + return "", fmt.Errorf("%s error: %s", domain, err) } - return "", fmt.Errorf("%s error: %s", domain, err) } - return s, nil + if ok, err := r.valid(); !ok { + return r.IP, err + } + + return r.IP, nil } -func request(ctx context.Context, cancel context.CancelFunc, url string) (string, error) { +func request(ctx context.Context, cancel context.CancelFunc, url string) (Result, error) { defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return "", err + return Result{}, err } resp, err := http.DefaultClient.Do(req) if err != nil { - return "", err + return Result{}, err } defer resp.Body.Close() //log.Printf("\nReceived %d from %s\n", resp.StatusCode, url) if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("%s, %w", strings.ToLower(resp.Status), ErrStatus) + return Result{}, fmt.Errorf("%s, %w", strings.ToLower(resp.Status), ErrStatus) } r, err := parse(resp.Body) if err != nil { - return "", err - } - - if ok, err := r.valid(); !ok { - return r.IP, err + return Result{}, err } - return r.IP, nil + return r, nil } func parse(r io.Reader) (Result, error) { diff --git a/lib/myipio/myipio_test.go b/lib/myipio/myipio_test.go index c04602e..3beea1b 100644 --- a/lib/myipio/myipio_test.go +++ b/lib/myipio/myipio_test.go @@ -19,13 +19,31 @@ func BenchmarkRequest(b *testing.B) { fmt.Println(s) } -func TestIPv4(t *testing.T) { - rc, rto := context.WithTimeout(context.Background(), 5*time.Second) - wantS, _ := request(rc, rto, link) +func TestTimeout(t *testing.T) { + ctx, timeout := context.WithTimeout(context.Background(), 0*time.Second) + if _, err := IPv4(ctx, timeout); !errors.Is(err, nil) { + t.Errorf("IPv4() = %v, want %v", err, nil) + } +} - ctx, timeout := context.WithTimeout(context.Background(), 5*time.Second) - if gotS, _ := IPv4(ctx, timeout); gotS != wantS { - t.Errorf("IPv4() = %v, want %v", gotS, wantS) +func TestCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + want := context.Canceled + _, err := IPv4(ctx, cancel) + if err == nil { + t.Errorf("IPv4() error = %v, want error string", err) + } + if !errors.Is(ctx.Err(), want) { + t.Errorf("IPv4() context.error = %v, want %v", err, want) + } +} + +func TestError(t *testing.T) { + link = "invalid url" + ctx, timeout := context.WithTimeout(context.Background(), 30*time.Second) + if _, err := IPv4(ctx, timeout); errors.Is(err, nil) { + t.Errorf("IPv4() = %v, want an error", err) } } @@ -38,8 +56,8 @@ func Test_request(t *testing.T) { }{ {"empty", "", false, "unsupported protocol scheme"}, {"html", "https://example.com", false, "invalid character"}, - {"404", link + "/abcdef", false, "404 not found"}, - {"okay", link, true, ""}, + {"404", "https://api4.my-ip.io/ip.json/abcdef", false, "404 not found"}, + {"okay", "https://api4.my-ip.io/ip.json", true, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -48,7 +66,7 @@ func Test_request(t *testing.T) { if err != nil && tt.wantErr != "" && !strings.Contains(fmt.Sprint(err), tt.wantErr) { t.Errorf("get() error = %v, want %v", err, tt.wantErr) } - if bool(gotS != "") != tt.isValid { + if bool(gotS.IP != "") != tt.isValid { t.Errorf("get() = %v, want an ip addr: %v", gotS, tt.isValid) } })