Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Implements missing HTTPS proxy functionality #978

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 124 additions & 42 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,34 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// The following custom dial functions can be set to establish
// connections to either the backend server or the proxy (if it
// exists). The scheme of the dialed entity (either backend or
// proxy) determines which custom dial function is selected:
// either NetDialTLSContext for HTTPS or NetDialContext/NetDial
// for HTTP. Since the "Proxy" function can determine the scheme
// dynamically, it can make sense to set multiple custom dial
// functions simultaneously.
//
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dialer DialContext is used.
// If "Proxy" field is also set, this function dials the proxy--not
// the backend server.
NetDial func(network, addr string) (net.Conn, error)

// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, NetDial is used.
// If "Proxy" field is also set, this function dials the proxy--not
// the backend server.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)

// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialContext is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
// If "Proxy" field is also set, this function dials the proxy (and performs
// the TLS handshake with the proxy, ignoring TLSClientConfig). In this TLS proxy
// dialing case the TLSClientConfig could still be necessary for TLS to the backend server.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)

// Proxy specifies a function to return a proxy for a given
Expand All @@ -73,7 +89,7 @@ type Dialer struct {

// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// If NetDialTLSContext is set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config

Expand Down Expand Up @@ -244,49 +260,16 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer cancel()
}

var netDial netDialerFunc
switch {
case u.Scheme == "https" && d.NetDialTLSContext != nil:
netDial = d.NetDialTLSContext
case d.NetDialContext != nil:
netDial = d.NetDialContext
case d.NetDial != nil:
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return d.NetDial(net, addr)
}
default:
netDial = (&net.Dialer{}).DialContext
}

// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial
netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := forwardDial(ctx, network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}

// If needed, wrap the dial function to connect through a proxy.
var proxyURL *url.URL
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
proxyURL, err = d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
netDial, err = proxyFromURL(proxyURL, netDial)
if err != nil {
return nil, nil, err
}
}
}
netDial, err := d.netDialFn(ctx, proxyURL, u)
if err != nil {
return nil, nil, err
}

hostPort, hostNoPort := hostPortNoPort(u)
Expand Down Expand Up @@ -317,8 +300,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()

if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
// Do TLS handshake over established connection if a proxy exists.
if proxyURL != nil && u.Scheme == "https" {

cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
Expand Down Expand Up @@ -415,6 +398,105 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil
}

// Returns the dial function to establish the connection to either the backend
// server or the proxy (if it exists). If the dialed entity is HTTPS, then the
// returned dial function *also* performs the TLS handshake to the dialed entity.
// NOTE: If a proxy exists, it is possible for a second TLS handshake to be
// necessary over the established connection.
func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *url.URL) (netDialerFunc, error) {
var netDial netDialerFunc
if proxyURL != nil {
netDial = d.netDialFromURL(proxyURL)
} else {
netDial = d.netDialFromURL(backendURL)
}
// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
netDial = netDialWithDeadline(netDial, deadline)
}
// Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth.
if proxyURL != nil {
return proxyFromURL(proxyURL, netDial)
}
return netDial, nil
}

// Returns function to create the connection depending on the Dialer's
// custom dialing functions and the passed URL of entity connecting to.
func (d *Dialer) netDialFromURL(u *url.URL) netDialerFunc {
var netDial netDialerFunc
switch {
case d.NetDialContext != nil:
netDial = d.NetDialContext
case d.NetDial != nil:
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return d.NetDial(net, addr)
}
default:
netDial = (&net.Dialer{}).DialContext
}
// If dialed entity is HTTPS, then either use custom TLS dialing function (if exists)
// or wrap the previously computed "netDial" to use TLS config for handshake.
if u.Scheme == "https" {
if d.NetDialTLSContext != nil {
netDial = d.NetDialTLSContext
} else {
netDial = netDialWithTLSHandshake(netDial, d.TLSClientConfig, u)
}
}
return netDial
}

// Returns wrapped "netDial" function, performing TLS handshake after connecting.
func netDialWithTLSHandshake(netDial netDialerFunc, tlsConfig *tls.Config, u *url.URL) netDialerFunc {
return func(ctx context.Context, unused, addr string) (net.Conn, error) {
hostPort, hostNoPort := hostPortNoPort(u)
trace := httptrace.ContextClientTrace(ctx)
if trace != nil && trace.GetConn != nil {
trace.GetConn(hostPort)
}
// Creates TCP connection to addr using passed "netDial" function.
conn, err := netDial(ctx, "tcp", addr)
if err != nil {
return nil, err
}
cfg := cloneTLSConfig(tlsConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(conn, cfg)
// Do the TLS handshake using TLSConfig over the wrapped connection.
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err = doHandshake(ctx, tlsConn, cfg)
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
if err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
}
}

// Returns wrapped "netDial" function, setting passed deadline.
func netDialWithDeadline(netDial netDialerFunc, deadline time.Time) netDialerFunc {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := netDial(ctx, network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}

func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
Expand Down
Loading