diff --git a/client.go b/client.go index 24bd7ff2..00917ea3 100644 --- a/client.go +++ b/client.go @@ -51,18 +51,34 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS // // It is safe to call Dialer's methods concurrently. type Dialer struct { + // The following custom dial functions can be set to establish + // connections to either the backend server or the proxy (if it + // exists). The scheme of the dialed entity (either backend or + // proxy) determines which custom dial function is selected: + // either NetDialTLSContext for HTTPS or NetDialContext/NetDial + // for HTTP. Since the "Proxy" function can determine the scheme + // dynamically, it can make sense to set multiple custom dial + // functions simultaneously. + // // NetDial specifies the dial function for creating TCP connections. If // NetDial is nil, net.Dialer DialContext is used. + // If "Proxy" field is also set, this function dials the proxy--not + // the backend server. NetDial func(network, addr string) (net.Conn, error) // NetDialContext specifies the dial function for creating TCP connections. If // NetDialContext is nil, NetDial is used. + // If "Proxy" field is also set, this function dials the proxy--not + // the backend server. NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If // NetDialTLSContext is nil, NetDialContext is used. // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and // TLSClientConfig is ignored. + // If "Proxy" field is also set, this function dials the proxy (and performs + // the TLS handshake with the proxy, ignoring TLSClientConfig). In this TLS proxy + // dialing case the TLSClientConfig could still be necessary for TLS to the backend server. NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) // Proxy specifies a function to return a proxy for a given @@ -73,7 +89,7 @@ type Dialer struct { // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. - // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // If NetDialTLSContext is set, Dial assumes the TLS handshake // is done there and TLSClientConfig is ignored. TLSClientConfig *tls.Config @@ -244,49 +260,16 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h defer cancel() } - var netDial netDialerFunc - switch { - case u.Scheme == "https" && d.NetDialTLSContext != nil: - netDial = d.NetDialTLSContext - case d.NetDialContext != nil: - netDial = d.NetDialContext - case d.NetDial != nil: - netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { - return d.NetDial(net, addr) - } - default: - netDial = (&net.Dialer{}).DialContext - } - - // If needed, wrap the dial function to set the connection deadline. - if deadline, ok := ctx.Deadline(); ok { - forwardDial := netDial - netDial = func(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := forwardDial(ctx, network, addr) - if err != nil { - return nil, err - } - err = c.SetDeadline(deadline) - if err != nil { - c.Close() - return nil, err - } - return c, nil - } - } - - // If needed, wrap the dial function to connect through a proxy. + var proxyURL *url.URL if d.Proxy != nil { - proxyURL, err := d.Proxy(req) + proxyURL, err = d.Proxy(req) if err != nil { return nil, nil, err } - if proxyURL != nil { - netDial, err = proxyFromURL(proxyURL, netDial) - if err != nil { - return nil, nil, err - } - } + } + netDial, err := d.netDialFn(ctx, proxyURL, u) + if err != nil { + return nil, nil, err } hostPort, hostNoPort := hostPortNoPort(u) @@ -317,8 +300,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" && d.NetDialTLSContext == nil { - // If NetDialTLSContext is set, assume that the TLS handshake has already been done + // Do TLS handshake over established connection if a proxy exists. + if proxyURL != nil && u.Scheme == "https" { cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { @@ -415,6 +398,105 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h return conn, resp, nil } +// Returns the dial function to establish the connection to either the backend +// server or the proxy (if it exists). If the dialed entity is HTTPS, then the +// returned dial function *also* performs the TLS handshake to the dialed entity. +// NOTE: If a proxy exists, it is possible for a second TLS handshake to be +// necessary over the established connection. +func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *url.URL) (netDialerFunc, error) { + var netDial netDialerFunc + if proxyURL != nil { + netDial = d.netDialFromURL(proxyURL) + } else { + netDial = d.netDialFromURL(backendURL) + } + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + netDial = netDialWithDeadline(netDial, deadline) + } + // Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth. + if proxyURL != nil { + return proxyFromURL(proxyURL, netDial) + } + return netDial, nil +} + +// Returns function to create the connection depending on the Dialer's +// custom dialing functions and the passed URL of entity connecting to. +func (d *Dialer) netDialFromURL(u *url.URL) netDialerFunc { + var netDial netDialerFunc + switch { + case d.NetDialContext != nil: + netDial = d.NetDialContext + case d.NetDial != nil: + netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { + return d.NetDial(net, addr) + } + default: + netDial = (&net.Dialer{}).DialContext + } + // If dialed entity is HTTPS, then either use custom TLS dialing function (if exists) + // or wrap the previously computed "netDial" to use TLS config for handshake. + if u.Scheme == "https" { + if d.NetDialTLSContext != nil { + netDial = d.NetDialTLSContext + } else { + netDial = netDialWithTLSHandshake(netDial, d.TLSClientConfig, u) + } + } + return netDial +} + +// Returns wrapped "netDial" function, performing TLS handshake after connecting. +func netDialWithTLSHandshake(netDial netDialerFunc, tlsConfig *tls.Config, u *url.URL) netDialerFunc { + return func(ctx context.Context, unused, addr string) (net.Conn, error) { + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + // Creates TCP connection to addr using passed "netDial" function. + conn, err := netDial(ctx, "tcp", addr) + if err != nil { + return nil, err + } + cfg := cloneTLSConfig(tlsConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(conn, cfg) + // Do the TLS handshake using TLSConfig over the wrapped connection. + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err = doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + if err != nil { + tlsConn.Close() + return nil, err + } + return tlsConn, nil + } +} + +// Returns wrapped "netDial" function, setting passed deadline. +func netDialWithDeadline(netDial netDialerFunc, deadline time.Time) netDialerFunc { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := netDial(ctx, network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + c.Close() + return nil, err + } + return c, nil + } +} + func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { return &tls.Config{} diff --git a/client_proxy_server_test.go b/client_proxy_server_test.go new file mode 100644 index 00000000..4bdd136e --- /dev/null +++ b/client_proxy_server_test.go @@ -0,0 +1,938 @@ +// Copyright 2025 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" +) + +// These test cases use a websocket client (Dialer)/proxy/websocket server (Upgrader) +// to validate the cases where a proxy is an intermediary between a websocket client +// and server. The test cases usually 1) create a websocket server which echoes any +// data received back to the client, 2) a basic duplex streaming proxy, and 3) a +// websocket client which sends random data to the server through the proxy, +// validating any subsequent data received is the same as the data sent. The various +// permutations include the proxy and backend schemes (HTTP or HTTPS), as well as +// the custom dial functions (e.g NetDialContext, NetDial) set on the Dialer. + +const ( + subprotocolV1 = "subprotocol-version-1" + subprotocolV2 = "subprotocol-version-2" +) + +// Permutation 1 +// +// Backend: HTTP +// Proxy: HTTP +func TestHTTPProxyAndBackend(t *testing.T) { + websocketTLS := false + proxyTLS := false + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + // Dial the websocket server through the proxy server. + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 2 +// +// Backend: HTTP +// Proxy: HTTP +// DialFn: NetDial (dials proxy) +func TestHTTPProxyWithNetDial(t *testing.T) { + websocketTLS := false + proxyTLS := false + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + // Dial the websocket server through the proxy server. + var netDialCalled atomic.Int64 + dialer := Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + netDialCalled.Add(1) + return (&net.Dialer{}).DialContext(context.Background(), network, addr) + }, + Proxy: http.ProxyURL(proxyServerURL), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), netDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 3 +// +// Backend: HTTP +// Proxy: HTTP +// DialFn: NetDialContext (dials proxy) +func TestHTTPProxyWithNetDialContext(t *testing.T) { + websocketTLS := false + proxyTLS := false + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + // Dial the websocket server through the proxy server. + var netDialCalled atomic.Int64 + dialer := Dialer{ + NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialCalled.Add(1) + return (&net.Dialer{}).DialContext(ctx, network, addr) + }, + Proxy: http.ProxyURL(proxyServerURL), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), netDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 4 +// +// Backend: HTTPS +// Proxy: HTTP +// DialFn: NetDialTLSConfig (set but *ignored*) +// TLS Config: set (used for backend TLS) +func TestHTTPProxyWithHTTPSBackend(t *testing.T) { + websocketTLS := true + proxyTLS := false + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + var netDialTLSCalled atomic.Int64 + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + // This function should be ignored, because an HTTP proxy exists + // and the backend TLS handshake should use TLSClientConfig. + NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialTLSCalled.Add(1) + return (&net.Dialer{}).DialContext(ctx, network, addr) + }, + // Used for the backend server TLS handshake. + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if numTLSDials := netDialTLSCalled.Load(); numTLSDials > 0 { + t.Errorf("NetDialTLS should have been ignored") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 5 +// +// Backend: HTTPS +// Proxy: HTTPS +// TLS Config: set (used for both proxy and backend TLS) +func TestHTTPSProxyAndBackend(t *testing.T) { + websocketTLS := true + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 6 +// +// Backend: HTTPS +// Proxy: HTTPS +// DialFn: NetDial (used to dial proxy) +// TLS Config: set (used for both proxy and backend TLS) +func TestHTTPSProxyUsingNetDial(t *testing.T) { + websocketTLS := true + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + var netDialCalled atomic.Int64 + dialer := Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + netDialCalled.Add(1) + return (&net.Dialer{}).DialContext(context.Background(), network, addr) + }, + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), netDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 7 +// +// Backend: HTTPS +// Proxy: HTTPS +// DialFn: NetDialContext (used to dial proxy) +// TLS Config: set (used for both proxy and backend TLS) +func TestHTTPSProxyUsingNetDialContext(t *testing.T) { + websocketTLS := true + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + var netDialCalled atomic.Int64 + dialer := Dialer{ + NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialCalled.Add(1) + return (&net.Dialer{}).DialContext(ctx, network, addr) + }, + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), netDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 8 +// +// Backend: HTTPS +// Proxy: HTTPS +// DialFn: NetDialTLSContext (used for proxy TLS) +// TLS Config: set (used for backend TLS) +func TestHTTPSProxyUsingNetDialTLSContext(t *testing.T) { + websocketTLS := true + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + // Configure the proxy dialing function which dials the proxy and + // performs the TLS handshake. + var proxyDialCalled atomic.Int64 + proxyCerts := x509.NewCertPool() + proxyCerts.AppendCertsFromPEM(proxyServerCert) + proxyTLSConfig := &tls.Config{RootCAs: proxyCerts} + proxyDial := func(ctx context.Context, network, addr string) (net.Conn, error) { + proxyDialCalled.Add(1) + return tls.Dial(network, addr, proxyTLSConfig) + } + // Configure the backend webscocket TLS configuration (handshake occurs + // over the previously created proxy connection). + websocketCerts := x509.NewCertPool() + websocketCerts.AppendCertsFromPEM(websocketServerCert) + websocketTLSConfig := &tls.Config{RootCAs: websocketCerts} + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + // Dial and TLS handshake function to proxy. + NetDialTLSContext: proxyDial, + // Used for second TLS handshake to backend server over previously + // established proxy connection. + TLSClientConfig: websocketTLSConfig, + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), proxyDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 9 +// +// Backend: HTTP +// Proxy: HTTPS +// TLS Config: set (used for proxy TLS) +func TestHTTPSProxyHTTPBackend(t *testing.T) { + websocketTLS := false + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +// Permutation 10 +// +// Backend: HTTP +// Proxy: HTTPS +// DialFn: NetDialTLSContext (used for proxy TLS) +// TLS Config: set (ignored) +func TestHTTPSProxyUsingNetDialTLSContextWithHTTPBackend(t *testing.T) { + websocketTLS := false + proxyTLS := true + // Start the websocket server, which echoes data back to sender. + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Start the proxy server. + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + var proxyDialCalled atomic.Int64 + dialer := Dialer{ + NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + proxyDialCalled.Add(1) + return tls.Dial(network, addr, tlsConfig(websocketTLS, proxyTLS)) + }, + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: &tls.Config{}, // Misconfigured, but ignored. + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Send, receive, and validate random data over websocket connection. + sendReceiveData(t, wsClient) + if e, a := int64(1), proxyDialCalled.Load(); e != a { + t.Errorf("netDial not called") + } + // Validate the proxy server was called. + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy not called") + } +} + +func TestTLSValidationErrors(t *testing.T) { + // Both websocket and proxy servers are started with TLS. + websocketTLS := true + proxyTLS := true + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + // Dialer without proxy CA cert fails TLS verification. + tlsError := "tls: failed to verify certificate" + dialer := Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(true, false), + Subprotocols: []string{subprotocolV1}, + } + _, _, err = dialer.Dial(websocketURL.String(), nil) + if err == nil { + t.Errorf("expected proxy TLS verification error did not arrive") + } else if !strings.Contains(err.Error(), tlsError) { + t.Errorf("expected proxy TLS error (%s), got (%s)", err.Error(), tlsError) + } + // Validate the proxy handler was *NOT* called (because proxy + // server TLS validation failed). + if e, a := int64(0), proxyServer.NumCalls(); e != a { + t.Errorf("proxy should not have been called") + } + // Dialer without websocket CA cert fails TLS verification. + dialer = Dialer{ + Proxy: http.ProxyURL(proxyServerURL), + TLSClientConfig: tlsConfig(false, true), + Subprotocols: []string{subprotocolV1}, + } + _, _, err = dialer.Dial(websocketURL.String(), nil) + if err == nil { + t.Errorf("expected websocket TLS verification error did not arrive") + } else if !strings.Contains(err.Error(), tlsError) { + t.Errorf("expected websocket TLS error (%s), got (%s)", err.Error(), tlsError) + } + // Validate the proxy server *was* called (but subsequent + // websocket server failed TLS validation). + if e, a := int64(1), proxyServer.NumCalls(); e != a { + t.Errorf("proxy have been called") + } +} + +func TestProxyFnErrorIsPropagated(t *testing.T) { + websocketServer, websocketURL, err := newWebsocketServer(false) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + // Create a Dialer where Proxy function always returns an error. + proxyURLError := errors.New("proxy URL generation error") + dialer := Dialer{ + Proxy: func(r *http.Request) (*url.URL, error) { + return nil, proxyURLError + }, + Subprotocols: []string{subprotocolV1}, + } + // Proxy URL generation error should halt request and be propagated. + _, _, err = dialer.Dial(websocketURL.String(), nil) + if err == nil { + t.Fatalf("expected websocket dial error, received none") + } else if !errors.Is(proxyURLError, err) { + t.Fatalf("expected error (%s), got (%s)", proxyURLError, err) + } +} + +func TestProxyFnNilMeansNoProxy(t *testing.T) { + // Both websocket and proxy servers are started. + websocketTLS := false + proxyTLS := false + websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) + defer websocketServer.Close() + if err != nil { + t.Fatalf("error starting websocket server: %v", err) + } + proxyServer, _, err := newProxyServer(proxyTLS) + defer proxyServer.Close() + if err != nil { + t.Fatalf("error starting proxy server: %v", err) + } + defer proxyServer.Close() + // Dialer created with Proxy URL generation function returning nil + // proxy URL, which continues with backend server connection without + // proxying. + dialer := Dialer{ + Proxy: func(r *http.Request) (*url.URL, error) { + return nil, nil + }, + Subprotocols: []string{subprotocolV1}, + } + wsClient, _, err := dialer.Dial(websocketURL.String(), nil) + if err != nil { + t.Fatalf("websocket dial error: %v", err) + } + // Create, send, receive, validate the random data. Backend + // server connection is successful (without a proxy). + randomData := make([]byte, randomDataSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + err = wsClient.WriteMessage(BinaryMessage, randomData) + if err != nil { + t.Errorf("websocket write error: %v", err) + } + _, received, err := wsClient.ReadMessage() + if !bytes.Equal(randomData, received) { + t.Errorf("unexpected data received: %d bytes sent, %d bytes received", + len(received), len(randomData)) + } + // Validate the proxy handler was *NOT* called (because proxy + // URL generation returned nil). + if e, a := int64(0), proxyServer.NumCalls(); e != a { + t.Errorf("proxy should not have been called") + } +} + +// websocketEchoHandler upgrades the connection associated with the request, and +// echoes binary messages read off the websocket connection back to the client. +var websocketEchoHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + upgrader := Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Accepting all requests + }, + Subprotocols: []string{ + subprotocolV1, + subprotocolV2, + }, + } + wsConn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + defer wsConn.Close() + for { + writer, err := wsConn.NextWriter(BinaryMessage) + if err != nil { + break + } + messageType, reader, err := wsConn.NextReader() + if err != nil { + break + } + if messageType != BinaryMessage { + http.Error(w, "websocket reader not binary message type", + http.StatusInternalServerError) + } + _, err = io.Copy(writer, reader) + if err != nil { + http.Error(w, "websocket server io copy error", + http.StatusInternalServerError) + } + } +}) + +type CounterCloser interface { + Counter + Closer +} + +type Counter interface { + Increment() + NumCalls() int64 +} + +type Closer interface { + Close() +} + +type testServer struct { + server *httptest.Server + numCalls atomic.Int64 +} + +func (ts *testServer) NumCalls() int64 { + return ts.numCalls.Load() +} + +func (ts *testServer) Increment() { + ts.numCalls.Add(1) +} + +func (ts *testServer) Close() { + if ts.server != nil { + ts.server.Close() + } +} + +// Returns a test backend websocket server as well as the URL pointing +// to the server, or an error if one occurred. Sets up a TLS endpoint +// on the server if the passed "tlsServer" is true. +// func newWebsocketServer(tlsServer bool) (*httptest.Server, *url.URL, error) { +func newWebsocketServer(tlsServer bool) (Closer, *url.URL, error) { + // Start the websocket server, which echoes data back to sender. + websocketServer := httptest.NewUnstartedServer(websocketEchoHandler) + if tlsServer { + websocketKeyPair, err := tls.X509KeyPair(websocketServerCert, websocketServerKey) + if err != nil { + return nil, nil, err + } + websocketServer.TLS = &tls.Config{ + Certificates: []tls.Certificate{websocketKeyPair}, + } + websocketServer.StartTLS() + } else { + websocketServer.Start() + } + websocketURL, err := url.Parse(websocketServer.URL) + if err != nil { + return nil, nil, err + } + if tlsServer { + websocketURL.Scheme = "wss" + } else { + websocketURL.Scheme = "ws" + } + return websocketServer, websocketURL, nil +} + +// proxyHandler creates a full duplex streaming connection between the client +// (hijacking the http request connection), and an "upstream" dialed connection +// to the "Host". Creates two goroutines to copy between connections in each direction. +var proxyHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Validate the CONNECT method. + if req.Method != http.MethodConnect { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + // Dial upstream server. + upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer upstream.Close() + // Return 200 OK to client. + w.WriteHeader(http.StatusOK) + // Hijack client connection. + client, _, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer client.Close() + // Create duplex streaming between client and upstream connections. + done := make(chan struct{}, 2) + go func() { + _, _ = io.Copy(upstream, client) + done <- struct{}{} + }() + go func() { + _, _ = io.Copy(client, upstream) + done <- struct{}{} + }() + <-done +}) + +// Returns a new test HTTP server, as well as the URL to that server, or +// an error if one occurred. numProxyCalls keeps track of the number of +// times the proxy handler was called with this server. +func newProxyServer(tlsServer bool) (CounterCloser, *url.URL, error) { + // Start the proxy server, keeping track of how many times the handler is called. + ts := &testServer{} + proxyServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ts.Increment() + proxyHandler.ServeHTTP(w, req) + })) + if tlsServer { + proxyKeyPair, err := tls.X509KeyPair(proxyServerCert, proxyServerKey) + if err != nil { + return nil, nil, err + } + proxyServer.TLS = &tls.Config{ + Certificates: []tls.Certificate{proxyKeyPair}, + } + proxyServer.StartTLS() + } else { + proxyServer.Start() + } + proxyURL, err := url.Parse(proxyServer.URL) + if err != nil { + return nil, nil, err + } + return ts, proxyURL, nil +} + +// Returns the TLS config with the RootCAs cert pool set. If +// neither websocket nor proxy server uses TLS, returns nil. +func tlsConfig(websocketTLS bool, proxyTLS bool) *tls.Config { + if !websocketTLS && !proxyTLS { + return nil + } + certPool := x509.NewCertPool() + tlsConfig := &tls.Config{ + RootCAs: certPool, + } + if websocketTLS { + tlsConfig.RootCAs.AppendCertsFromPEM(websocketServerCert) + } + if proxyTLS { + tlsConfig.RootCAs.AppendCertsFromPEM(proxyServerCert) + } + return tlsConfig +} + +// Sends, receives, and validates random data sent and received +// over the passed websocket connection. +const randomDataSize = 128 * 1024 + +func sendReceiveData(t *testing.T, wsConn *Conn) { + // Create the random data. + randomData := make([]byte, randomDataSize) + if _, err := rand.Read(randomData); err != nil { + t.Errorf("unexpected error reading random data: %v", err) + } + // Send the random data. + err := wsConn.WriteMessage(BinaryMessage, randomData) + if err != nil { + t.Errorf("websocket write error: %v", err) + } + // Read from the websocket connection, and validate the + // read data is the same as the previously sent data. + _, received, err := wsConn.ReadMessage() + if !bytes.Equal(randomData, received) { + t.Errorf("unexpected data received: %d bytes sent, %d bytes received", + len(received), len(randomData)) + } +} + +// proxyServerCert was generated from crypto/tls/generate_cert.go with the following command: +// +// go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +// +// proxyServerCert is a self-signed. +var proxyServerCert = []byte(`-----BEGIN CERTIFICATE----- +MIIDGTCCAgGgAwIBAgIRALL5AZcefF4kkYV1SEG6YrMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 +MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBALQ/FHcyVwdFHxARbbD2KBtDUT7Eni+8ioNdjtGcmtXqBv45EC1C +JOqqGJTroFGJ6Q9kQIZ9FqH5IJR2fOOJD9kOTueG4Vt1JY1rj1Kbpjefu8XleZ5L +SBwIWVnN/lEsEbuKmj7N2gLt5AH3zMZiBI1mg1u9Z5ZZHYbCiTpBrwsq6cTlvR9g +dyo1YkM5hRESCzsrL0aUByoo0qRMD8ZsgANJwgsiO0/M6idbxDwv1BnGwGmRYvOE +Hxpy3v0Jg7GJYrvnpnifJTs4nw91N5X9pXxR7FFzi/6HTYDWRljvTb0w6XciKYAz +bWZ0+cJr5F7wB7ovlbm7HrQIR7z7EIIu2d8CAwEAAaNoMGYwDgYDVR0PAQH/BAQD +AgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wLgYDVR0R +BCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAAAAAAAAEwDQYJKoZI +hvcNAQELBQADggEBAFPPWopNEJtIA2VFAQcqN6uJK+JVFOnjGRoCrM6Xgzdm0wxY +XCGjsxY5dl+V7KzdGqu858rCaq5osEBqypBpYAnS9C38VyCDA1vPS1PsN8SYv48z +DyBwj+7R2qar0ADBhnhWxvYO9M72lN/wuCqFKYMeFSnJdQLv3AsrrHe9lYqOa36s +8wxSwVTFTYXBzljPEnSaaJMPqFD8JXaZK1ryJPkO5OsCNQNGtatNiWAf3DcmwHAT +MGYMzP0u4nw47aRz9shB8w+taPKHx2BVwE1m/yp3nHVioOjXqA1fwRQVGclCJSH1 +D2iq3hWVHRENgjTjANBPICLo9AZ4JfN6PH19mnU= +-----END CERTIFICATE-----`) + +// proxyServerKey is the private key for proxyServerCert. +var proxyServerKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAtD8UdzJXB0UfEBFtsPYoG0NRPsSeL7yKg12O0Zya1eoG/jkQ +LUIk6qoYlOugUYnpD2RAhn0WofkglHZ844kP2Q5O54bhW3UljWuPUpumN5+7xeV5 +nktIHAhZWc3+USwRu4qaPs3aAu3kAffMxmIEjWaDW71nllkdhsKJOkGvCyrpxOW9 +H2B3KjViQzmFERILOysvRpQHKijSpEwPxmyAA0nCCyI7T8zqJ1vEPC/UGcbAaZFi +84QfGnLe/QmDsYliu+emeJ8lOzifD3U3lf2lfFHsUXOL/odNgNZGWO9NvTDpdyIp +gDNtZnT5wmvkXvAHui+VubsetAhHvPsQgi7Z3wIDAQABAoIBAGmw93IxjYCQ0ncc +kSKMJNZfsdtJdaxuNRZ0nNNirhQzR2h403iGaZlEpmdkhzxozsWcto1l+gh+SdFk +bTUK4MUZM8FlgO2dEqkLYh5BcMT7ICMZvSfJ4v21E5eqR68XVUqQKoQbNvQyxFk3 +EddeEGdNrkb0GDK8DKlBlzAW5ep4gjG85wSTjR+J+muUv3R0BgLBFSuQnIDM/IMB +LWqsja/QbtB7yppe7jL5u8UCFdZG8BBKT9fcvFIu5PRLO3MO0uOI7LTc8+W1Xm23 +uv+j3SY0+v+6POjK0UlJFFi/wkSPTFIfrQO1qFBkTDQHhQ6q/7GnILYYOiGbIRg2 +NNuP52ECgYEAzXEoy50wSYh8xfFaBuxbm3ruuG2W49jgop7ZfoFrPWwOQKAZS441 +VIwV4+e5IcA6KkuYbtGSdTYqK1SMkgnUyD/VevwAqH5TJoEIGu0pDuKGwVuwqioZ +frCIAV5GllKyUJ55VZNbRr2vY2fCsWbaCSCHETn6C16DNuTCe5C0JBECgYEA4JqY +5GpNbMG8fOt4H7hU0Fbm2yd6SHJcQ3/9iimef7xG6ajxsYrIhg1ft+3IPHMjVI0+ +9brwHDnWg4bOOx/VO4VJBt6Dm/F33bndnZRkuIjfSNpLM51P+EnRdaFVHOJHwKqx +uF69kihifCAG7YATgCveeXImzBUSyZUz9UrETu8CgYARNBimdFNG1RcdvEg9rC0/ +p9u1tfecvNySwZqU7WF9kz7eSonTueTdX521qAHowaAdSpdJMGODTTXaywm6cPhQ +jIfj9JZZhbqQzt1O4+08Qdvm9TamCUB5S28YLjza+bHU7nBaqixKkDfPqzCyilpX +yVGGL8SwjwmN3zop/sQXAQKBgC0JMsESQ6YcDsRpnrOVjYQc+LtW5iEitTdfsaID +iGGKihmOI7B66IxgoCHMTws39wycKdSyADVYr5e97xpR3rrJlgQHmBIrz+Iow7Q2 +LiAGaec8xjl6QK/DdXmFuQBKqyKJ14rljFODP4QuE9WJid94bGqjpf3j99ltznZP +4J8HAoGAJb4eb4lu4UGwifDzqfAPzLGCoi0fE1/hSx34lfuLcc1G+LEu9YDKoOVJ +9suOh0b5K/bfEy9KrVMBBriduvdaERSD8S3pkIQaitIz0B029AbE4FLFf9lKQpP2 +KR8NJEkK99Vh/tew6jAMll70xFrE7aF8VLXJVE7w4sQzuvHxl9Q= +-----END RSA PRIVATE KEY----- +`) + +// websocketServerCert is self-signed. +var websocketServerCert = []byte(`-----BEGIN CERTIFICATE----- +MIIDOTCCAiGgAwIBAgIQYSN1VY/favsLUo+B7gJ5tTANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEApBlintjkL1fO1Sk2pzNvl862CtTwU7/Jy6EZqWzI17wEbPn4sbSD +bHhfDlPl2nmw3hVkc6LNK+eqzm2GX/ai4tgMiaH7kyyNit1K3g7y7GISMf9poWIa +POJhid2wmhKHbEtHECSdQ5c/jEN1UVzB4go5LO7MEEVo9kyQ+yBqS6gISyFmfaT4 +qOsPJBir33bBpptSend1JSXaRTXqRa1p+oudw2ILa4U7KfuKK3emp21m5/HYAuSf +CV4WqqDoDiBPMpsQ0kPEPugWZKFeF3qanmqFFvptYx+zJbOznWYY2D3idWsvcg6q +VLPEB19oXaVBV0HXPFtObm5m1jCpl8FI1wIDAQABo4GIMIGFMA4GA1UdDwEB/wQE +AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud +DgQWBBQcSkjqA9rgos1daegNj49BpRCA0jAuBgNVHREEJzAlggtleGFtcGxlLmNv +bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAnk9i +9rogNTi9B1pn+Fbk3WALKdEjv/uyePsTnwdyvswVbeYbQweU9TrhYT2+eXbMA5kY +7TaQm46idRqxCKMgc3Ip3DADJdm8cJX9p2ExU4fKdkPc1KD/J+4QHHx1W2Ml5S2o +foOo6j1F0UdZP/rBj0UumEZp32qW+4DhVV/QQjUB8J0gaDC7yZBMdyMIeClR0RqE +YfZdCJbQHqtTwBXN+imQUHPGmksYkRDpFRvw/4crpcMIE04mVVd99nOpFCQnK61t +9US1y17VW1lYpkqlCS+rkcAtor4Z5naSf9/oLGCxEAwyW0pwHGO6MXtMxvB/JD20 +hJdlz1I7wlSfF4MiRQ== +-----END CERTIFICATE-----`) + +// websocketServerKey is the private key for websocketServerCert. +var websocketServerKey = []byte(`-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCkGWKe2OQvV87V +KTanM2+XzrYK1PBTv8nLoRmpbMjXvARs+fixtINseF8OU+XaebDeFWRzos0r56rO +bYZf9qLi2AyJofuTLI2K3UreDvLsYhIx/2mhYho84mGJ3bCaEodsS0cQJJ1Dlz+M +Q3VRXMHiCjks7swQRWj2TJD7IGpLqAhLIWZ9pPio6w8kGKvfdsGmm1J6d3UlJdpF +NepFrWn6i53DYgtrhTsp+4ord6anbWbn8dgC5J8JXhaqoOgOIE8ymxDSQ8Q+6BZk +oV4XepqeaoUW+m1jH7Mls7OdZhjYPeJ1ay9yDqpUs8QHX2hdpUFXQdc8W05ubmbW +MKmXwUjXAgMBAAECggEAE6BkTDgH//rnkP/Ej/Y17Zkv6qxnMLe/4evwZB7PsrBu +cxOUAWUOpvA1UO215bh87+2XvcDbUISnyC1kpKDyAGGeC5llER2DXE11VokWgtvZ +Q0OXavw5w83A+WVGFFdiUmXP0l10CxEm7OwQjFz6D21GQ1qC65tG9NZZghTxbFTe +iZKqgWqyHsaAWLOuDQbj1FTEBMFrY8f9RbclSh0luPZnzGc4BVI/t34jKPZBpH2N +NCkr8aB7MMHGhrNZFHAu/KAvq8UBrDTX+O8ERMwcwQWB4nne2+GOTN0MdcAUc72i +GryzIa8TgO+TpQOYoZ4NPnzFrsa+m3G2Tug3vbt62QKBgQDOPfM4/5/x/h/ggxQn +aRvEOC+8ldeqEOS1VTGiuDKJMWXrNkG+d+AsxfNP4k0QVNrpEAZSYcf0gnS9Odcl +luEsi/yPZDDnPg/cS+Z3336VKsggly7BWFs1Ct/9I+ZfSCl88TkVpIfeCBC34XEb +0mFUq/RdLqXj/mVLbBfr+H8cEwKBgQDLsJUm8lkWFAPJ8UMto8xeUMGk44VukYwx ++oI6KhplFntiI0C1Dd9wrxyCjySlJcc0NFt6IPN84d7pI9LQSbiKXQ1jMvsBzd4G +EMtG8SHpIY/mMU+KzWLHYVFS0FA4PvXXvPRNLOXas7hbALZdLshVKd7aDlkQAb5C +KWFHeIFwrQKBgA8r5Xl67HQrwoKMge4IQF+l1nUj/LJo/boNI1KaBDWtaZbs7dcq +EFaa1TQ6LHsYEuZ0JFLpGIF3G0lUOOxt9fCF97VApIxON3J4LuMAkNo+RGyJUoos +isETJLkFbAv0TgD/6bga21fM9hXgwqZOSpSk9ZvpM5DbBO6QbA4SwJ77AoGAX7h1 +/z14XAW/2hDE7xfAnLn6plA9jj5b0cjVlhvfF44/IVlLuUnxrPS9wyUdpXZhbMkG +DBicFB3ZMVqiYTuju3ILLojwqGJkahlOTeJXe0VIaHbX2HS4bNXw76fxat07jsy/ +Sd1Fj0dR5YIqMRQhFNR+Y57Gf90x2cm0a2/X9GkCgYANawYx9bNfcX0HMVG7vktK +6/80omnoBM0JUxA+V7DxS8kr9Cj2Y/kcS+VHb4yyoSkDgnsSdnCr1ZTctcj828MJ +8AUwskAtEjPkHRXEgRRnEl2oJGD1TT5iwBNnuPAQDXwzkGCRYBnlfZNbILbOoSUz +m+VDcqT5XzcRADa/TLlEXA== +-----END PRIVATE KEY----- +`) diff --git a/proxy.go b/proxy.go index b4683b9f..d716a058 100644 --- a/proxy.go +++ b/proxy.go @@ -29,7 +29,7 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) ( } func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { - if proxyURL.Scheme == "http" { + if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil } dialer, err := proxy.FromURL(proxyURL, forwardDial) @@ -64,7 +64,6 @@ func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, add connectHeader.Set("Proxy-Authorization", "Basic "+credential) } } - connectReq := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Opaque: addr},