diff --git a/progressbar.go b/progressbar.go index 0ccec7d..460360b 100644 --- a/progressbar.go +++ b/progressbar.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "log" "math" "net/http" "os" @@ -449,17 +448,15 @@ func NewOptions64(max int64, options ...Option) *ProgressBar { go func() { ticker := time.NewTicker(b.config.spinnerChangeInterval) defer ticker.Stop() - for { - select { - case <-ticker.C: - if b.IsFinished() { - return - } - if b.IsStarted() { - b.lock.Lock() - b.render() - b.lock.Unlock() - } + + for range ticker.C { + if b.IsFinished() { + return + } + if b.IsStarted() { + b.lock.Lock() + b.render() + b.lock.Unlock() } } }() @@ -1014,27 +1011,48 @@ func (p *ProgressBar) State() State { // StartHTTPServer starts an HTTP server dedicated to serving progress bar updates. This allows you to // display the status in various UI elements, such as an OS status bar with an `xbar` extension. -// It is recommended to run this function in a separate goroutine to avoid blocking the main thread. +// When the progress bar is finished, call `server.Shutdown()` or `server.Close()` to shut it down manually. // // hostPort specifies the address and port to bind the server to, for example, "0.0.0.0:19999". -func (p *ProgressBar) StartHTTPServer(hostPort string) { - // for advanced users, we can return the data as json - http.HandleFunc("/state", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/json") - // since the state is a simple struct, we can just ignore the error +func (p *ProgressBar) StartHTTPServer(hostPort string) *http.Server { + mux := http.NewServeMux() + + // register routes + mux.HandleFunc("/state", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") bs, _ := json.Marshal(p.State()) w.Write(bs) }) - // for others, we just return the description in a plain text format - http.HandleFunc("/desc", func(w http.ResponseWriter, r *http.Request) { + + mux.HandleFunc("/desc", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") + state := p.State() fmt.Fprintf(w, "%d/%d, %.2f%%, %s left", - p.State().CurrentNum, p.State().Max, p.State().CurrentPercent*100, - (time.Second * time.Duration(p.State().SecondsLeft)).String(), + state.CurrentNum, state.Max, state.CurrentPercent*100, + (time.Second * time.Duration(state.SecondsLeft)).String(), ) }) - log.Fatal(http.ListenAndServe(hostPort, nil)) + + // create the server instance + server := &http.Server{ + Addr: hostPort, + Handler: mux, + } + + // start the server in a goroutine and ignore errors + go func() { + defer func() { + if err := recover(); err != nil { + fmt.Println("encounter panic: ", err) + } + }() + + _ = server.ListenAndServe() + }() + + // return the server instance for use by the caller + return server } // regex matching ansi escape codes diff --git a/progressbar_test.go b/progressbar_test.go index 1a4aad7..fd6ac3a 100644 --- a/progressbar_test.go +++ b/progressbar_test.go @@ -2,6 +2,7 @@ package progressbar import ( "bytes" + "context" "crypto/md5" "encoding/hex" "encoding/json" @@ -471,7 +472,7 @@ func TestOptionSetTheme(t *testing.T) { bar.RenderBlank() result := strings.TrimSpace(buf.String()) expect := "0% >----------<" - if strings.Index(result, expect) == -1 { + if !strings.Contains(result, expect) { t.Errorf("Render miss-match\nResult: '%s'\nExpect: '%s'\n%+v", result, expect, bar) } buf.Reset() @@ -487,7 +488,7 @@ func TestOptionSetTheme(t *testing.T) { bar.Finish() result = strings.TrimSpace(buf.String()) expect = "100% >##########<" - if strings.Index(result, expect) == -1 { + if !strings.Contains(result, expect) { t.Errorf("Render miss-match\nResult: '%s'\nExpect: '%s'\n%+v", result, expect, bar) } } @@ -506,7 +507,7 @@ func TestOptionSetThemeFilled(t *testing.T) { bar.RenderBlank() result := strings.TrimSpace(buf.String()) expect := "0% >----------<" - if strings.Index(result, expect) == -1 { + if !strings.Contains(result, expect) { t.Errorf("Render miss-match\nResult: '%s'\nExpect: '%s'\n%+v", result, expect, bar) } buf.Reset() @@ -522,7 +523,7 @@ func TestOptionSetThemeFilled(t *testing.T) { bar.Finish() result = strings.TrimSpace(buf.String()) expect = "100% ]##########[" - if strings.Index(result, expect) == -1 { + if !strings.Contains(result, expect) { t.Errorf("Render miss-match\nResult: '%s'\nExpect: '%s'\n%+v", result, expect, bar) } } @@ -1067,7 +1068,7 @@ func TestOptionSetSpinnerChangeIntervalZero(t *testing.T) { bar.lock.Lock() s, _ := vt.String() bar.lock.Unlock() - s = strings.TrimSpace(s) + _ = strings.TrimSpace(s) } for i := range actuals { assert.Equal(t, expected[i], actuals[i]) @@ -1130,7 +1131,7 @@ func TestStartHTTPServer(t *testing.T) { bar.Add(1) hostPort := "localhost:9696" - go bar.StartHTTPServer(hostPort) + svr := bar.StartHTTPServer(hostPort) // check plain text resp, err := http.Get(fmt.Sprintf("http://%s/desc", hostPort)) @@ -1162,4 +1163,19 @@ func TestStartHTTPServer(t *testing.T) { if result.Max != bar.State().Max || result.CurrentNum != bar.State().CurrentNum { t.Errorf("wrong state: %v", result) } + + // shutdown server + err = svr.Shutdown(context.Background()) + if err != nil { + t.Errorf("shutdown server failed: %v", err) + } + + // start new bar server + bar = Default(10, "test") + bar.Add(1) + svr = bar.StartHTTPServer(hostPort) + err = svr.Close() + if err != nil { + t.Errorf("shutdown server failed: %v", err) + } }