package http2 import ( "bufio" "bytes" "crypto/tls" "errors" "fmt" "io" "net" "os" "strconv" "sync" "sync/atomic" "time" "github.com/valyala/fasthttp" ) // ConnOpts defines the connection options. type ConnOpts struct { // PingInterval defines the interval in which the client will ping the server. // // An interval of <=0 will make the library to use DefaultPingInterval. Because ping intervals can't be disabled PingInterval time.Duration // DisablePingChecking ... DisablePingChecking bool // OnDisconnect is a callback that fires when the Conn disconnects. OnDisconnect func(c *Conn) } // Handshake performs an HTTP/2 handshake. That means, it will send // the preface if `preface` is true, send a settings frame and a // window update frame (for the connection's window). // TODO: explain more func Handshake(preface bool, bw *bufio.Writer, st *Settings, maxWin int32) error { if preface { err := WritePreface(bw) if err != nil { return err } } fr := AcquireFrameHeader() defer ReleaseFrameHeader(fr) // write the settings st2 := &Settings{} st.CopyTo(st2) fr.SetBody(st2) _, err := fr.WriteTo(bw) if err == nil { // then send a window update fr := AcquireFrameHeader() wu := AcquireFrame(FrameWindowUpdate).(*WindowUpdate) wu.SetIncrement(int(maxWin)) fr.SetBody(wu) _, err = fr.WriteTo(bw) if err == nil { err = bw.Flush() } ReleaseFrameHeader(fr) } return err } // Conn represents a raw HTTP/2 connection over TLS + TCP. type Conn struct { c net.Conn br *bufio.Reader bw *bufio.Writer enc *HPACK dec *HPACK nextID uint32 serverWindow int32 serverStreamWindow int32 maxWindow int32 currentWindow int32 openStreams int32 current Settings serverS Settings state connState closeRef uint32 reqQueued sync.Map in chan *Ctx out chan *FrameHeader pingInterval time.Duration unacks int disableAcks bool lastErr error onDisconnect func(*Conn) closed uint64 } // NewConn returns a new HTTP/2 connection. // To start using the connection you need to call Handshake. func NewConn(c net.Conn, opts ConnOpts) *Conn { nc := &Conn{ c: c, br: bufio.NewReaderSize(c, 4096), bw: bufio.NewWriterSize(c, maxFrameSize), enc: AcquireHPACK(), dec: AcquireHPACK(), nextID: 1, maxWindow: 1 << 20, currentWindow: 1 << 20, in: make(chan *Ctx, 128), out: make(chan *FrameHeader, 128), pingInterval: opts.PingInterval, disableAcks: opts.DisablePingChecking, onDisconnect: opts.OnDisconnect, } nc.current.SetMaxWindowSize(1 << 20) nc.current.SetPush(false) return nc } // Dialer allows creating HTTP/2 connections by specifying an address and tls configuration. type Dialer struct { // Addr is the server's address in the form: `host:port`. Addr string // TLSConfig is the tls configuration. // // If TLSConfig is nil, a default one will be defined on the Dial call. TLSConfig *tls.Config // PingInterval defines the interval in which the client will ping the server. // // An interval of 0 will make the library to use DefaultPingInterval. Because ping intervals can't be disabled. PingInterval time.Duration // NetDial defines the callback for establishing new connection to the host. // Default Dial is used if not set. NetDial fasthttp.DialFunc } func (d *Dialer) tryDial() (net.Conn, error) { if d.TLSConfig == nil || !func() bool { for _, proto := range d.TLSConfig.NextProtos { if proto == "h2" { return true } } return false }() { configureDialer(d) } var c net.Conn var err error if d.NetDial != nil { c, err = d.NetDial(d.Addr) if err != nil { return nil, err } } else { tcpAddr, err := net.ResolveTCPAddr("tcp", d.Addr) if err != nil { return nil, err } c, err = net.DialTCP("tcp", nil, tcpAddr) if err != nil { return nil, err } } tlsConn := tls.Client(c, d.TLSConfig) if err := tlsConn.Handshake(); err != nil { _ = c.Close() return nil, err } if tlsConn.ConnectionState().NegotiatedProtocol != "h2" { _ = c.Close() return nil, ErrServerSupport } return tlsConn, nil } // Dial creates an HTTP/2 connection or returns an error. // // An expected error is ErrServerSupport. func (d *Dialer) Dial(opts ConnOpts) (*Conn, error) { c, err := d.tryDial() if err != nil { return nil, err } nc := NewConn(c, opts) err = nc.Handshake() return nc, err } // SetOnDisconnect sets the callback that will fire when the HTTP/2 connection is closed. func (c *Conn) SetOnDisconnect(cb func(*Conn)) { c.onDisconnect = cb } // LastErr returns the last registered error in case the connection was closed by the server. func (c *Conn) LastErr() error { return c.lastErr } // Handshake will perform the necessary handshake to establish the connection // with the server. If an error is returned you can assume the TCP connection has been closed. func (c *Conn) Handshake() error { err := c.doHandshake() if err == nil { go c.writeLoop() go c.readLoop() } return err } func (c *Conn) doHandshake() error { var err error if err = Handshake(true, c.bw, &c.current, c.maxWindow-65535); err != nil { _ = c.c.Close() return err } var fr *FrameHeader if fr, err = ReadFrameFrom(c.br); err == nil && fr.Type() != FrameSettings { _ = c.c.Close() return fmt.Errorf("unexpected frame, expected settings, got %s", fr.Type()) } else if err == nil { st := fr.Body().(*Settings) if !st.IsAck() { st.CopyTo(&c.serverS) c.serverStreamWindow += int32(c.serverS.MaxWindowSize()) if st.HeaderTableSize() <= defaultHeaderTableSize { c.enc.SetMaxTableSize(st.HeaderTableSize()) } // reply back fr := AcquireFrameHeader() stRes := AcquireFrame(FrameSettings).(*Settings) stRes.SetAck(true) fr.SetBody(stRes) if _, err = fr.WriteTo(c.bw); err == nil { err = c.bw.Flush() } ReleaseFrameHeader(fr) } } if err != nil { _ = c.c.Close() } else { ReleaseFrameHeader(fr) } return err } // CanOpenStream returns whether the client will be able to open a new stream or not. func (c *Conn) CanOpenStream() bool { return atomic.LoadInt32(&c.openStreams) < int32(c.serverS.maxStreams) } // Closed indicates whether the connection is closed or not. func (c *Conn) Closed() bool { return atomic.LoadUint64(&c.closed) == 1 } // Close closes the connection gracefully, sending a GoAway message // and then closing the underlying TCP connection. func (c *Conn) Close() error { if !atomic.CompareAndSwapUint64(&c.closed, 0, 1) { return io.EOF } close(c.in) fr := AcquireFrameHeader() defer ReleaseFrameHeader(fr) ga := AcquireFrame(FrameGoAway).(*GoAway) ga.SetStream(0) ga.SetCode(NoError) fr.SetBody(ga) _, err := fr.WriteTo(c.bw) if err == nil { err = c.bw.Flush() } _ = c.c.Close() if c.onDisconnect != nil { c.onDisconnect(c) } return err } // Write queues the request to be sent to the server. // // Check if `c` has been previously closed before accessing this function. func (c *Conn) Write(r *Ctx) { c.in <- r } var ErrStreamNotReady = errors.New("stream hasn't been created") // Cancel will try to cancel the request. // // Cancel can only return ErrStreamNotReady when the cancel is performed before the stream is created. func (c *Conn) Cancel(ctx *Ctx) error { if atomic.LoadUint32(&ctx.streamID) == 0 { return ErrStreamNotReady } c.cancel(ctx) return nil } func (c *Conn) cancel(ctx *Ctx) { h := AcquireFrameHeader() h.SetStream( // TODO: use atomic here?? atomic.LoadUint32(&ctx.streamID)) fr := AcquireFrame(FrameResetStream).(*RstStream) fr.SetCode(StreamCanceled) h.SetBody(fr) c.out <- h } type WriteError struct { err error } func (we WriteError) Error() string { return fmt.Sprintf("writing error: %s", we.err) } func (we WriteError) Unwrap() error { return we.err } func (we WriteError) Is(target error) bool { return errors.Is(we.err, target) } func (we WriteError) As(target interface{}) bool { return errors.As(we.err, target) } func (c *Conn) writeLoop() { var lastErr error defer func() { _ = c.Close() }() defer func() { if err := recover(); err != nil { if lastErr == nil { switch errn := err.(type) { case error: lastErr = errn case string: lastErr = errors.New(errn) } } } if lastErr == nil { lastErr = io.ErrUnexpectedEOF } c.reqQueued.Range(func(_, v interface{}) bool { r := v.(*Ctx) r.resolve(lastErr) return true }) }() if c.pingInterval <= 0 { c.pingInterval = DefaultPingInterval } ticker := time.NewTicker(c.pingInterval) defer ticker.Stop() loop: for { select { case ctx, ok := <-c.in: // sending requests if !ok { break loop } err := c.writeRequest(ctx) if err != nil { ctx.resolve(err) if errors.Is(err, ErrNotAvailableStreams) { continue } lastErr = WriteError{err} break loop } case fr, ok := <-c.out: // generic output if !ok { break loop } err := c.writeFrame(fr) if err != nil { lastErr = WriteError{err} break loop } ReleaseFrameHeader(fr) case <-ticker.C: // ping if err := c.writePing(); err != nil { lastErr = WriteError{err} break loop } } if !c.disableAcks && c.unacks >= 3 { lastErr = ErrTimeout break loop } } } func (c *Conn) writeFrame(fr *FrameHeader) error { _, err := fr.WriteTo(c.bw) if err == nil { if err = c.bw.Flush(); err != nil { return err } } return err } func (c *Conn) finish(r *Ctx, stream uint32, err error) { atomic.AddInt32(&c.openStreams, -1) r.resolve(err) c.reqQueued.Delete(stream) } func (c *Conn) readLoop() { defer func() { _ = c.Close() }() for { fr, err := c.readNext() if err != nil { c.lastErr = err break } // TODO: panic otherwise? if ri, ok := c.reqQueued.Load(fr.Stream()); ok { r := ri.(*Ctx) err := c.readStream(fr, r.Response) if err == nil { if fr.Flags().Has(FlagEndStream) { c.finish(r, fr.Stream(), nil) } } else { c.finish(r, fr.Stream(), err) fmt.Fprintf(os.Stderr, "%s. payload=%v\n", err, fr.payload) if errors.Is(err, FlowControlError) { break } } if c.state == connStateClosed { if fr.Stream() == c.closeRef { break } } } ReleaseFrameHeader(fr) } } func (c *Conn) writeRequest(ctx *Ctx) error { if !c.CanOpenStream() { return ErrNotAvailableStreams } req := ctx.Request hasBody := len(req.Body()) != 0 enc := c.enc id := c.nextID c.nextID += 2 fr := AcquireFrameHeader() defer ReleaseFrameHeader(fr) fr.SetStream(id) h := AcquireFrame(FrameHeaders).(*Headers) fr.SetBody(h) hf := AcquireHeaderField() hf.SetBytes(StringAuthority, req.URI().Host()) enc.AppendHeaderField(h, hf, true) hf.SetBytes(StringMethod, req.Header.Method()) enc.AppendHeaderField(h, hf, true) hf.SetBytes(StringPath, req.URI().RequestURI()) enc.AppendHeaderField(h, hf, true) hf.SetBytes(StringScheme, req.URI().Scheme()) enc.AppendHeaderField(h, hf, true) hf.SetBytes(StringUserAgent, req.Header.UserAgent()) enc.AppendHeaderField(h, hf, true) req.Header.VisitAll(func(k, v []byte) { if bytes.EqualFold(k, StringUserAgent) { return } hf.SetBytes(ToLower(k), v) enc.AppendHeaderField(h, hf, false) }) h.SetPadding(false) h.SetEndStream(!hasBody) h.SetEndHeaders(true) // store the ctx before sending the request atomic.StoreUint32(&ctx.streamID, id) c.reqQueued.Store(id, ctx) _, err := fr.WriteTo(c.bw) if err == nil && hasBody { // release headers bc it's going to get replaced by the data frame ReleaseFrame(h) err = writeData(c.bw, fr, req.Body()) } if err == nil { err = c.bw.Flush() if err == nil { atomic.AddInt32(&c.openStreams, 1) } } if err != nil { c.lastErr = err // if we had any error, remove it from the reqQueued. c.reqQueued.Delete(id) } ReleaseHeaderField(hf) return err } func writeData(bw *bufio.Writer, fh *FrameHeader, body []byte) (err error) { step := 1 << 14 data := AcquireFrame(FrameData).(*Data) fh.SetBody(data) for i := 0; err == nil && i < len(body); i += step { if i+step >= len(body) { step = len(body) - i } data.SetEndStream(i+step == len(body)) data.SetPadding(false) data.SetData(body[i : step+i]) _, err = fh.WriteTo(bw) } return err } func (c *Conn) readNext() (fr *FrameHeader, err error) { loop: for err == nil { fr, err = ReadFrameFrom(c.br) if err != nil { break } if fr.Stream() != 0 { break } switch fr.Type() { case FrameSettings: st := fr.Body().(*Settings) if !st.IsAck() { // if it has ack, just ignore c.handleSettings(st) } case FrameWindowUpdate: win := int32(fr.Body().(*WindowUpdate).Increment()) atomic.AddInt32(&c.serverWindow, win) case FramePing: ping := fr.Body().(*Ping) if !ping.IsAck() { c.handlePing(ping) } else { c.unacks-- } case FrameGoAway: ga := fr.Body().(*GoAway) if ga.stream == 0 { _ = c.c.Close() err = ga } else { // wait for the streams to complete c.closeRef = ga.stream c.state = connStateClosed } break loop } ReleaseFrameHeader(fr) } return } var ErrTimeout = errors.New("server is not replying to pings") func (c *Conn) writePing() error { fr := AcquireFrameHeader() defer ReleaseFrameHeader(fr) ping := AcquireFrame(FramePing).(*Ping) ping.SetCurrentTime() fr.SetBody(ping) _, err := fr.WriteTo(c.bw) if err == nil { err = c.bw.Flush() if err == nil { c.unacks++ } } return err } func (c *Conn) handleSettings(st *Settings) { st.CopyTo(&c.serverS) c.serverStreamWindow += int32(c.serverS.MaxWindowSize()) c.enc.SetMaxTableSize(st.HeaderTableSize()) // reply back fr := AcquireFrameHeader() stRes := AcquireFrame(FrameSettings).(*Settings) stRes.SetAck(true) fr.SetBody(stRes) c.out <- fr } func (c *Conn) handlePing(ping *Ping) { // reply back fr := AcquireFrameHeader() ping.SetAck(true) fr.SetBody(ping) c.out <- fr } func (c *Conn) readStream(fr *FrameHeader, res *fasthttp.Response) (err error) { switch fr.Type() { case FrameHeaders, FrameContinuation: h := fr.Body().(FrameWithHeaders) err = c.readHeader(h.Headers(), res) case FrameData: c.currentWindow -= int32(fr.Len()) currentWin := c.currentWindow c.serverWindow -= int32(fr.Len()) data := fr.Body().(*Data) if data.Len() != 0 { res.AppendBody(data.Data()) // let's send the window update c.updateWindow(fr.Stream(), fr.Len()) } if currentWin < c.maxWindow/2 { nValue := c.maxWindow - currentWin c.currentWindow = c.maxWindow c.updateWindow(0, int(nValue)) } } return } func (c *Conn) updateWindow(streamID uint32, size int) { fr := AcquireFrameHeader() fr.SetStream(streamID) wu := AcquireFrame(FrameWindowUpdate).(*WindowUpdate) wu.SetIncrement(size) fr.SetBody(wu) c.out <- fr } func (c *Conn) readHeader(b []byte, res *fasthttp.Response) error { var err error hf := AcquireHeaderField() defer ReleaseHeaderField(hf) dec := c.dec for len(b) > 0 { b, err = dec.Next(hf, b) if err != nil { return err } if hf.IsPseudo() { if hf.KeyBytes()[1] == 's' { // status n, err := strconv.ParseInt(hf.Value(), 10, 64) if err != nil { return err } res.SetStatusCode(int(n)) continue } } if bytes.Equal(hf.KeyBytes(), StringContentLength) { n, _ := strconv.Atoi(hf.Value()) res.Header.SetContentLength(n) } else { res.Header.AddBytesKV(hf.KeyBytes(), hf.ValueBytes()) } } return nil }