Skip to content

Commit

Permalink
Implements HTTPS proxy functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
seans3 committed Mar 5, 2025
1 parent 5e00238 commit 0b0f26a
Show file tree
Hide file tree
Showing 3 changed files with 1,137 additions and 45 deletions.
165 changes: 122 additions & 43 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,34 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// The following custom dial functions can be set to establish
// connections to either the backend server or the proxy (if it
// exists). The scheme of the dialed entity (either backend or
// proxy) determines which custom dial function is selected:
// either NetDialTLSContext for HTTPS or NetDialContext/NetDial
// for HTTP. Since the "Proxy" function can determine the scheme
// dynamically, it can make sense to set multiple custom dial
// functions simultaneously.
//
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dialer DialContext is used.
// If "Proxy" field is also set, this function dials the proxy--not
// the backend server.
NetDial func(network, addr string) (net.Conn, error)

// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, NetDial is used.
// If "Proxy" field is also set, this function dials the proxy--not
// the backend server.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)

// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialContext is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
// If "Proxy" field is also set, this function dials the proxy (and performs
// the TLS handshake with the proxy, ignoring TLSClientConfig). In this TLS proxy
// dialing case the TLSClientConfig could still be necessary for TLS to the backend server.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)

// Proxy specifies a function to return a proxy for a given
Expand All @@ -73,7 +89,7 @@ type Dialer struct {

// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// If NetDialTLSContext is set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config

Expand Down Expand Up @@ -244,49 +260,16 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer cancel()
}

var netDial netDialerFunc
switch {
case u.Scheme == "https" && d.NetDialTLSContext != nil:
netDial = d.NetDialTLSContext
case d.NetDialContext != nil:
netDial = d.NetDialContext
case d.NetDial != nil:
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return d.NetDial(net, addr)
}
default:
netDial = (&net.Dialer{}).DialContext
}

// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial
netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := forwardDial(ctx, network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}

// If needed, wrap the dial function to connect through a proxy.
var proxyURL *url.URL
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
proxyURL, err = d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
netDial, err = proxyFromURL(proxyURL, netDial)
if err != nil {
return nil, nil, err
}
}
}
netDial, err := d.netDialFn(ctx, u, proxyURL)
if err != nil {
return nil, nil, err
}

hostPort, hostNoPort := hostPortNoPort(u)
Expand Down Expand Up @@ -317,9 +300,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()

if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done

// Do TLS handshake over "netConn" connection if necessary.
if d.needsTLSHandshake(u, proxyURL) {
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
Expand Down Expand Up @@ -415,6 +397,103 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil
}

// Returns the dial function to establish the connection to either the backend
// server or the proxy (if it exists). Instead returns an error if one occurred.
func (d *Dialer) netDialFn(ctx context.Context, backendURL *url.URL, proxyURL *url.URL) (netDialerFunc, error) {
netDial := d.netDialFromScheme(backendURL.Scheme)
if proxyURL != nil {
netDial = d.netDialFromScheme(proxyURL.Scheme)
// Wrap proxy dial function to perform TLS handshake if necessary.
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
netDial = netDialWithTLSHandshake(netDial, d.TLSClientConfig, proxyURL)
}
}
// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
netDial = netDialWithDeadline(netDial, deadline)
}
// Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth.
if proxyURL != nil {
return proxyFromURL(proxyURL, netDial)
}
return netDial, nil
}

// Returns function to create the connection depending on the Dialer's
// custom dialing functions and the passed scheme of entity connecting to.
func (d *Dialer) netDialFromScheme(scheme string) netDialerFunc {
var netDial netDialerFunc
switch {
case scheme == "https" && d.NetDialTLSContext != nil:
netDial = d.NetDialTLSContext
case d.NetDialContext != nil:
netDial = d.NetDialContext
case d.NetDial != nil:
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return d.NetDial(net, addr)
}
default:
netDial = (&net.Dialer{}).DialContext
}
return netDial
}

// Returns true if a TLS handshake needs to be performed to the backend server
// after the connection has been established (since some dialing functions *also*
// perform TLS handshake).
func (d *Dialer) needsTLSHandshake(backendURL *url.URL, proxyURL *url.URL) bool {
if backendURL.Scheme != "https" {
return false
}
// If a proxy exists, we will always need to do a TLS handshake.
if proxyURL != nil {
return true
}
// Otherwise, we will need to do a TLS handshake to the backend only
// if it has not already been performed by NetDialTLSContext.
return d.NetDialTLSContext == nil
}

// Returns wrapped "netDial" function, performing TLS handshake after connecting.
func netDialWithTLSHandshake(netDial netDialerFunc, tlsConfig *tls.Config, u *url.URL) netDialerFunc {
return func(ctx context.Context, unused, addr string) (net.Conn, error) {
// Creates TCP connection to addr using passed "netDial" function.
conn, err := netDial(ctx, "tcp", addr)
if err != nil {
return nil, err
}
cfg := cloneTLSConfig(tlsConfig)
if cfg.ServerName == "" {
_, hostNoPort := hostPortNoPort(u)
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(conn, cfg)
// Do the TLS handshake using TLSConfig over the wrapped connection.
err = doHandshake(ctx, tlsConn, cfg)
if err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
}
}

// Returns wrapped "netDial" function, setting passed deadline.
func netDialWithDeadline(netDial netDialerFunc, deadline time.Time) netDialerFunc {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := netDial(ctx, network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}

func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
Expand Down
Loading

0 comments on commit 0b0f26a

Please # to comment.