From 0319061b2543067460267056909ed7ee2d6437f4 Mon Sep 17 00:00:00 2001 From: Igor Shishkin Date: Wed, 28 Aug 2024 08:36:04 +0300 Subject: [PATCH] Allow to preserve scheme on redirect to object (#191) Closes #187 --------- Signed-off-by: Igor Shishkin --- cmd/publisher/main.go | 10 +++-- publisher/presenter/html/handlers.go | 47 ++++++++++++++++++----- publisher/presenter/html/handlers_test.go | 44 ++++++++++++++++++++- 3 files changed, 87 insertions(+), 14 deletions(-) diff --git a/cmd/publisher/main.go b/cmd/publisher/main.go index 21ed601..ae6daa2 100644 --- a/cmd/publisher/main.go +++ b/cmd/publisher/main.go @@ -52,12 +52,14 @@ type config struct { BLOBS3DisableSSL bool `envconfig:"BLOB_S3_DISABLE_SSL" default:"false"` BLOBS3ForcePathStyle bool `envconfig:"BLOB_S3_FORCE_PATH_STYLE" default:"true"` + BLOBS3PreserveSchemeOnRedirect bool `envconfig:"BLOB_S3_PRESERVE_SCHEME_ON_REDIRECT" default:"true"` + HTMLTemplateDir string `envconfig:"HTML_TEMPLATE_DIR" required:"true"` StaticDir string `envconfig:"STATIC_DIR" required:"true"` - VersionsPerPage uint64 `envconfig:"VERSIONS_PER_PAGE" default:"50"` - ObjectsPerPage uint64 `envconfig:"OBJECTS_PER_PAGE" default:"50"` - ContainersPerPage uint64 `envconfig:"CONTAINERS_PER_PAGE" default:"50"` + VersionsPerPage uint64 `envconfig:"VERSIONS_PER_PAGE" default:"50"` + ObjectsPerPage uint64 `envconfig:"OBJECTS_PER_PAGE" default:"50"` + ContainersPerPage uint64 `envconfig:"CONTAINERS_PER_PAGE" default:"50"` } func main() { @@ -122,7 +124,7 @@ func main() { publisherSvc := service.NewPublisher(repo, blobRepo, cfg.VersionsPerPage, cfg.ObjectsPerPage, cfg.ContainersPerPage) - p := htmlPresenter.New(publisherSvc, cfg.HTMLTemplateDir, cfg.StaticDir) + p := htmlPresenter.New(publisherSvc, cfg.HTMLTemplateDir, cfg.StaticDir, cfg.BLOBS3PreserveSchemeOnRedirect) p.Register(e) g.Go(func() error { diff --git a/publisher/presenter/html/handlers.go b/publisher/presenter/html/handlers.go index 213b077..57e54f0 100644 --- a/publisher/presenter/html/handlers.go +++ b/publisher/presenter/html/handlers.go @@ -29,16 +29,18 @@ type Handlers interface { } type handlers struct { - svc service.Publisher - staticDir string - templateDir string + svc service.Publisher + staticDir string + templateDir string + preserveSchemeOnRedirect bool } -func New(svc service.Publisher, templateDir, staticDir string) Handlers { +func New(svc service.Publisher, templateDir, staticDir string, preserveSchemeOnRedirect bool) Handlers { return &handlers{ - svc: svc, - staticDir: staticDir, - templateDir: templateDir, + svc: svc, + staticDir: staticDir, + templateDir: templateDir, + preserveSchemeOnRedirect: preserveSchemeOnRedirect, } } @@ -195,7 +197,7 @@ func (h *handlers) GetObject(c echo.Context) error { return err } - url, err := h.svc.GetObjectURL(c.Request().Context(), namespace, container, version, key) + link, err := h.svc.GetObjectURL(c.Request().Context(), namespace, container, version, key) if err != nil { if err == service.ErrNotFound { return c.Render(http.StatusNotFound, notFoundTemplateFilename, nil) @@ -203,7 +205,34 @@ func (h *handlers) GetObject(c echo.Context) error { return err } - return c.Redirect(http.StatusFound, url) + xForwardedScheme := c.Request().Header.Get("X-Forwarded-Scheme") + xScheme := c.Request().Header.Get("X-Scheme") + + allowedValues := map[string]struct{}{ + "http": {}, + "https": {}, + } + + _, xForwardedSchemeOk := allowedValues[xForwardedScheme] + _, xSchemeOk := allowedValues[xScheme] + + if h.preserveSchemeOnRedirect && (xForwardedSchemeOk || xSchemeOk) { + scheme := xForwardedScheme + if !xForwardedSchemeOk { + scheme = xScheme + } + + u, err := url.Parse(link) + if err != nil { + return c.Blob(http.StatusInternalServerError, "text/plain", []byte("error parsing url")) + } + + u.Scheme = scheme + + link = u.String() + } + + return c.Redirect(http.StatusFound, link) } func (h *handlers) ErrorHandler(err error, c echo.Context) { diff --git a/publisher/presenter/html/handlers_test.go b/publisher/presenter/html/handlers_test.go index 3348484..cc8feea 100644 --- a/publisher/presenter/html/handlers_test.go +++ b/publisher/presenter/html/handlers_test.go @@ -61,6 +61,48 @@ func (s *handlersTestSuite) TestGetObject() { s.Require().Equal("https://example.com/some-addr", v) } +func (s *handlersTestSuite) TestGetObjectSchemeMismatchXForwardedScheme() { + s.serviceMock.On("GetObjectURL", defaultNamespace, "test-container-1", "20241011121314", "test-dir/filename.txt").Return("https://example.com/some-addr", nil).Once() + + req, err := http.NewRequest(http.MethodGet, s.srv.URL+"/default/test-container-1/20241011121314/test-dir/filename.txt", nil) + s.Require().NoError(err) + + req.Header.Set("X-Forwarded-Scheme", "http") + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Do(req) + s.Require().NoError(err) + + v := resp.Header.Get("Location") + s.Require().Equal("http://example.com/some-addr", v) +} + +func (s *handlersTestSuite) TestGetObjectSchemeMismatchXScheme() { + s.serviceMock.On("GetObjectURL", defaultNamespace, "test-container-1", "20241011121314", "test-dir/filename.txt").Return("https://example.com/some-addr", nil).Once() + + req, err := http.NewRequest(http.MethodGet, s.srv.URL+"/default/test-container-1/20241011121314/test-dir/filename.txt", nil) + s.Require().NoError(err) + + req.Header.Set("X-Scheme", "http") + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Do(req) + s.Require().NoError(err) + + v := resp.Header.Get("Location") + s.Require().Equal("http://example.com/some-addr", v) +} + func (s *handlersTestSuite) TestErrNotFound() { s.serviceMock.On("ListPublishedVersionsByPage", defaultNamespace, "test-container-1", uint64(1)).Return(uint64(100), []models.Version(nil), service.ErrNotFound).Once() s.compareHTMLResponse(s.srv.URL+"/default/test-container-1/", "testdata/404.html.sample") @@ -116,7 +158,7 @@ func (s *handlersTestSuite) SetupTest() { s.serviceMock = service.NewMock() - s.handlers = New(s.serviceMock, "templates", "static") + s.handlers = New(s.serviceMock, "templates", "static", true) s.handlers.Register(e) s.srv = httptest.NewServer(e)