diff --git a/integration/app_integration_test.go b/integration/app_integration_test.go index dae4286deb87a..4374539ebc6b1 100644 --- a/integration/app_integration_test.go +++ b/integration/app_integration_test.go @@ -23,6 +23,7 @@ import ( "crypto/x509" "encoding/json" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -189,6 +190,45 @@ func TestAppAccessClientCert(t *testing.T) { } } +// TestAppAccessFlush makes sure that application access periodically flushes +// buffered data to the response. +func TestAppAccessFlush(t *testing.T) { + pack := setup(t) + + req, err := http.NewRequest("GET", pack.assembleRootProxyURL("/"), nil) + require.NoError(t, err) + + cookie := pack.createAppSession(t, pack.flushAppPublicAddr, pack.flushAppClusterName) + req.AddCookie(&http.Cookie{ + Name: app.CookieName, + Value: cookie, + }) + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // The "flush server" will send 2 messages, "hello" and "world", with a + // 500ms delay between them. They should arrive as 2 different frames + // due to the periodic flushing. + frames := []string{"hello", "world"} + for _, frame := range frames { + buffer := make([]byte, 1024) + n, err := resp.Body.Read(buffer) + if err != nil { + require.ErrorIs(t, err, io.EOF) + } + require.Equal(t, frame, strings.TrimSpace(string(buffer[:n]))) + } +} + // TestAppAccessForwardModes ensures that requests are forwarded to applications // even when the cluster is in proxy recording mode. func TestAppAccessForwardModes(t *testing.T) { @@ -585,6 +625,10 @@ type pack struct { headerAppName string headerAppPublicAddr string headerAppClusterName string + + flushAppName string + flushAppPublicAddr string + flushAppClusterName string } type appTestOptions struct { @@ -648,6 +692,10 @@ func setupWithOptions(t *testing.T, opts appTestOptions) *pack { headerAppName: "app-04", headerAppPublicAddr: "app-04.example.com", headerAppClusterName: "example.com", + + flushAppName: "app-05", + flushAppPublicAddr: "app-05.example.com", + flushAppClusterName: "example.com", } // Start a few different HTTP server that will be acting like a proxied application. @@ -693,6 +741,25 @@ func setupWithOptions(t *testing.T, opts appTestOptions) *pack { } })) t.Cleanup(headerServer.Close) + flushServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.(http.Hijacker) + conn, _, err := h.Hijack() + require.NoError(t, err) + defer conn.Close() + data := "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "05\r\n" + + "hello\r\n" + fmt.Fprint(conn, data) + time.Sleep(500 * time.Millisecond) + data = "05\r\n" + + "world\r\n" + + "0\r\n" + + "\r\n" + fmt.Fprint(conn, data) + })) + t.Cleanup(flushServer.Close) p.jwtAppURI = jwtServer.URL @@ -808,6 +875,11 @@ func setupWithOptions(t *testing.T, opts appTestOptions) *pack { URI: headerServer.URL, PublicAddr: p.headerAppPublicAddr, }, + { + Name: p.flushAppName, + URI: flushServer.URL, + PublicAddr: p.flushAppPublicAddr, + }, }, opts.extraRootApps...) p.rootAppServer, err = p.rootCluster.StartApp(raConf) require.NoError(t, err) diff --git a/lib/srv/app/session.go b/lib/srv/app/session.go index 54dcf221bac7d..47b658359fdaa 100644 --- a/lib/srv/app/session.go +++ b/lib/srv/app/session.go @@ -87,6 +87,7 @@ func (s *Server) newSession(ctx context.Context, identity *tlsca.Identity, app t return nil, trace.Wrap(err) } fwd, err := forward.New( + forward.FlushInterval(100*time.Millisecond), forward.RoundTripper(transport), forward.Logger(logrus.StandardLogger()), forward.WebsocketRewriter(transport.ws), diff --git a/lib/web/app/session.go b/lib/web/app/session.go index 8ed8eab4f41fa..72782f71ab0a7 100644 --- a/lib/web/app/session.go +++ b/lib/web/app/session.go @@ -81,6 +81,7 @@ func (h *Handler) newSession(ctx context.Context, ws types.WebSession) (*session return nil, trace.Wrap(err) } fwd, err := forward.New( + forward.FlushInterval(100*time.Millisecond), forward.RoundTripper(transport), forward.Logger(h.log), forward.PassHostHeader(true),