Skip to content

Commit

Permalink
Cancel and timeout contexts are handled correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
bengarrett committed Apr 12, 2021
1 parent 4f24abf commit 6d167a5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 33 deletions.
47 changes: 23 additions & 24 deletions lib/myipio/myipio.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"net"
"net/http"
"net/url"
"strings"
)

Expand All @@ -34,61 +33,61 @@ var (
ErrNoIPv4 = errors.New("ip address is not ipv4")
ErrInvalid = errors.New("ip address is invalid")
ErrStatus = errors.New("unusual my-ip.io server response")
)

const (
domain = "api4.my-ip.io"
link = "https://api4.my-ip.io/ip.json"
link = "https://api4.my-ip.io/ip.json"
)

const domain = "api4.my-ip.io"

// IPv4 returns the Internet facing IP address of the free my-ip.io service.
func IPv4(ctx context.Context, cancel context.CancelFunc) (string, error) {
s, err := request(ctx, cancel, link)
r, err := request(ctx, cancel, link)
if err == nil && ctx.Err() == context.Canceled {
return "", nil
}
if err != nil {
if _, ok := err.(*url.Error); ok {
if strings.Contains(err.Error(), "context deadline exceeded") {
fmt.Printf("\n%s: timeout\n", domain)
return "", nil
}
fmt.Printf("\n%s: %s\n", domain, err)
switch errors.Unwrap(err) {
case context.DeadlineExceeded:
fmt.Printf("\n%s: timeout\n", domain)
return "", nil
default:
return "", fmt.Errorf("%s error: %s", domain, err)
}
return "", fmt.Errorf("%s error: %s", domain, err)
}

return s, nil
if ok, err := r.valid(); !ok {
return r.IP, err
}

return r.IP, nil
}

func request(ctx context.Context, cancel context.CancelFunc, url string) (string, error) {
func request(ctx context.Context, cancel context.CancelFunc, url string) (Result, error) {
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", err
return Result{}, err
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
return Result{}, err
}
defer resp.Body.Close()

//log.Printf("\nReceived %d from %s\n", resp.StatusCode, url)

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("%s, %w", strings.ToLower(resp.Status), ErrStatus)
return Result{}, fmt.Errorf("%s, %w", strings.ToLower(resp.Status), ErrStatus)
}

r, err := parse(resp.Body)
if err != nil {
return "", err
}

if ok, err := r.valid(); !ok {
return r.IP, err
return Result{}, err
}

return r.IP, nil
return r, nil
}

func parse(r io.Reader) (Result, error) {
Expand Down
36 changes: 27 additions & 9 deletions lib/myipio/myipio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,31 @@ func BenchmarkRequest(b *testing.B) {
fmt.Println(s)
}

func TestIPv4(t *testing.T) {
rc, rto := context.WithTimeout(context.Background(), 5*time.Second)
wantS, _ := request(rc, rto, link)
func TestTimeout(t *testing.T) {
ctx, timeout := context.WithTimeout(context.Background(), 0*time.Second)
if _, err := IPv4(ctx, timeout); !errors.Is(err, nil) {
t.Errorf("IPv4() = %v, want %v", err, nil)
}
}

ctx, timeout := context.WithTimeout(context.Background(), 5*time.Second)
if gotS, _ := IPv4(ctx, timeout); gotS != wantS {
t.Errorf("IPv4() = %v, want %v", gotS, wantS)
func TestCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
want := context.Canceled
_, err := IPv4(ctx, cancel)
if err == nil {
t.Errorf("IPv4() error = %v, want error string", err)
}
if !errors.Is(ctx.Err(), want) {
t.Errorf("IPv4() context.error = %v, want %v", err, want)
}
}

func TestError(t *testing.T) {
link = "invalid url"
ctx, timeout := context.WithTimeout(context.Background(), 30*time.Second)
if _, err := IPv4(ctx, timeout); errors.Is(err, nil) {
t.Errorf("IPv4() = %v, want an error", err)
}
}

Expand All @@ -38,8 +56,8 @@ func Test_request(t *testing.T) {
}{
{"empty", "", false, "unsupported protocol scheme"},
{"html", "https://example.com", false, "invalid character"},
{"404", link + "/abcdef", false, "404 not found"},
{"okay", link, true, ""},
{"404", "https://api4.my-ip.io/ip.json/abcdef", false, "404 not found"},
{"okay", "https://api4.my-ip.io/ip.json", true, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -48,7 +66,7 @@ func Test_request(t *testing.T) {
if err != nil && tt.wantErr != "" && !strings.Contains(fmt.Sprint(err), tt.wantErr) {
t.Errorf("get() error = %v, want %v", err, tt.wantErr)
}
if bool(gotS != "") != tt.isValid {
if bool(gotS.IP != "") != tt.isValid {
t.Errorf("get() = %v, want an ip addr: %v", gotS, tt.isValid)
}
})
Expand Down

0 comments on commit 6d167a5

Please # to comment.