diff --git a/internal/daemon/api.go b/internal/daemon/api.go index 2a994707d..3ebbf370f 100644 --- a/internal/daemon/api.go +++ b/internal/daemon/api.go @@ -70,10 +70,10 @@ var api = []*Command{{ UserOK: true, POST: v1PostLayers, }, { - Path: "/v1/files", - UserOK: true, - GET: v1GetFiles, - POST: v1PostFiles, + Path: "/v1/files", + AdminOnly: true, + GET: v1GetFiles, + POST: v1PostFiles, }, { Path: "/v1/logs", UserOK: true, @@ -83,9 +83,9 @@ var api = []*Command{{ UserOK: true, POST: v1PostExec, }, { - Path: "/v1/tasks/{task-id}/websocket/{websocket-id}", - UserOK: true, - GET: v1GetTaskWebsocket, + Path: "/v1/tasks/{task-id}/websocket/{websocket-id}", + AdminOnly: true, + GET: v1GetTaskWebsocket, }, { Path: "/v1/signals", UserOK: true, diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 870b5c36b..704b8b14a 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -18,12 +18,15 @@ import ( "bytes" "encoding/json" "fmt" + "io" "io/ioutil" "net" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" + "strings" "sync" "syscall" "testing" @@ -1148,3 +1151,90 @@ services: c.Assert(tasks, HasLen, 1) c.Check(tasks[0].Kind(), Equals, "stop") } + +func (s *daemonSuite) TestAPIAccessLevels(c *C) { + _ = s.newDaemon(c) + + tests := []struct { + method string + path string + body string + uid int // -1 means no peer cred user + status int + }{ + {"GET", "/v1/system-info", ``, -1, http.StatusOK}, + + {"GET", "/v1/health", ``, -1, http.StatusOK}, + + {"GET", "/v1/warnings", ``, -1, http.StatusUnauthorized}, + {"GET", "/v1/warnings", ``, 42, http.StatusOK}, + {"GET", "/v1/warnings", ``, 0, http.StatusOK}, + {"POST", "/v1/warnings", ``, -1, http.StatusUnauthorized}, + {"POST", "/v1/warnings", ``, 42, http.StatusUnauthorized}, + {"POST", "/v1/warnings", ``, 0, http.StatusBadRequest}, + + {"GET", "/v1/changes", ``, -1, http.StatusUnauthorized}, + {"GET", "/v1/changes", ``, 42, http.StatusOK}, + {"GET", "/v1/changes", ``, 0, http.StatusOK}, + + {"GET", "/v1/services", ``, -1, http.StatusUnauthorized}, + {"GET", "/v1/services", ``, 42, http.StatusOK}, + {"GET", "/v1/services", ``, 0, http.StatusOK}, + {"POST", "/v1/services", ``, -1, http.StatusUnauthorized}, + {"POST", "/v1/services", ``, 42, http.StatusUnauthorized}, + {"POST", "/v1/services", ``, 0, http.StatusBadRequest}, + + {"POST", "/v1/layers", ``, -1, http.StatusUnauthorized}, + {"POST", "/v1/layers", ``, 42, http.StatusUnauthorized}, + {"POST", "/v1/layers", ``, 0, http.StatusBadRequest}, + + {"GET", "/v1/files?action=list&path=/", ``, -1, http.StatusUnauthorized}, + {"GET", "/v1/files?action=list&path=/", ``, 42, http.StatusUnauthorized}, // even reading files requires admin + {"GET", "/v1/files?action=list&path=/", ``, 0, http.StatusOK}, + {"POST", "/v1/files", `{}`, -1, http.StatusUnauthorized}, + {"POST", "/v1/files", `{}`, 42, http.StatusUnauthorized}, + {"POST", "/v1/files", `{}`, 0, http.StatusBadRequest}, + + {"GET", "/v1/logs", ``, -1, http.StatusUnauthorized}, + {"GET", "/v1/logs", ``, 42, http.StatusOK}, + {"GET", "/v1/logs", ``, 0, http.StatusOK}, + + {"POST", "/v1/exec", `{}`, -1, http.StatusUnauthorized}, + {"POST", "/v1/exec", `{}`, 42, http.StatusUnauthorized}, + {"POST", "/v1/exec", `{}`, 0, http.StatusBadRequest}, + + {"POST", "/v1/signals", `{}`, -1, http.StatusUnauthorized}, + {"POST", "/v1/signals", `{}`, 42, http.StatusUnauthorized}, + {"POST", "/v1/signals", `{}`, 0, http.StatusBadRequest}, + + {"GET", "/v1/checks", ``, -1, http.StatusUnauthorized}, + {"GET", "/v1/checks", ``, 42, http.StatusOK}, + {"GET", "/v1/checks", ``, 0, http.StatusOK}, + } + + for _, test := range tests { + remoteAddr := "" + if test.uid >= 0 { + remoteAddr = fmt.Sprintf("pid=100;uid=%d;socket=;", test.uid) + } + requestURL, err := url.Parse("http://localhost" + test.path) + c.Assert(err, IsNil) + request := &http.Request{ + Method: test.method, + URL: requestURL, + Body: io.NopCloser(strings.NewReader(test.body)), + RemoteAddr: remoteAddr, + } + recorder := httptest.NewRecorder() + cmd := apiCmd(requestURL.Path) + cmd.ServeHTTP(recorder, request) + + response := recorder.Result() + if response.StatusCode != test.status { + // Log response body to make it easier to debug if the test fails. + c.Logf("%s %s uid=%d: expected %d, got %d; response body:\n%s", + test.method, test.path, test.uid, test.status, response.StatusCode, recorder.Body.String()) + } + c.Assert(response.StatusCode, Equals, test.status) + } +}