package rest

import (
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/rs/zerolog"
	"golang.org/x/time/rate"
)

func NewLeakyBucketRatelimiter() *LeakyBucketRatelimiter {
	return &LeakyBucketRatelimiter{
		buckets:  make(map[string]*rate.Limiter),
		routeMap: make(map[string]string),
	}
}

type LeakyBucketRatelimiter struct {
	sync.RWMutex
	buckets  map[string]*rate.Limiter
	routeMap map[string]string
}

func (c *LeakyBucketRatelimiter) Request(httpClient HTTPClient, req *request) (*DiscordResponse, error) {
	url, err := url.Parse(req.path)
	if err != nil {
		return nil, err
	}
	bucket := c.GetBucket(url.Path)
	if bucket != nil {
		res := bucket.Reserve()
		if !res.OK() {
			return nil, ErrMaxRetriesExceeded
		}
		zerolog.Ctx(req.ctx).Debug().Msgf("Ratelimited request to %s, waiting %v", url.Path, res.Delay())
		time.Sleep(res.Delay())
	}

	resp, err := httpClient.Request(req)
	if err != nil {
		return nil, err
	}

	c.updateFromResponse(resp.Header, req.path)

	return resp, nil
}

func (c *LeakyBucketRatelimiter) bucketExists(name string) bool {
	c.RLock()
	defer c.RUnlock()
	_, ok := c.buckets[name]
	return ok
}

func (c *LeakyBucketRatelimiter) mappingExists(name string) bool {
	c.RLock()
	defer c.RUnlock()
	_, ok := c.routeMap[name]
	return ok
}

func (c *LeakyBucketRatelimiter) addBucket(name string, r int, count int) {
	c.Lock()
	defer c.Unlock()
	c.buckets[name] = rate.NewLimiter(rate.Limit(count/r), count)
	// Reserve a ticket since we JUST made a request
	c.buckets[name].Reserve()
}

func (c *LeakyBucketRatelimiter) addMapping(routeKey, bucket string) {
	c.Lock()
	defer c.Unlock()
	c.routeMap[routeKey] = bucket
}

func (c *LeakyBucketRatelimiter) GetBucket(path string) *rate.Limiter {
	c.RLock()
	defer c.RUnlock()
	routeKey := parseRoute(path)
	b, ok := c.buckets[c.routeMap[routeKey]]
	if !ok {
		return nil
	}
	return b
}

func (c *LeakyBucketRatelimiter) updateFromResponse(h http.Header, path string) {
	bucket := h.Get("X-RateLimit-Bucket")
	limitHeader := h.Get("X-RateLimit-Limit")
	resetAfter := h.Get("X-RateLimit-Reset-After")

	count, err := strconv.ParseInt(limitHeader, 10, 64)
	if err != nil {
		return
	}

	reset, err := strconv.ParseInt(resetAfter, 10, 64)
	if err != nil {
		return
	}

	if !c.bucketExists(bucket) {
		// Upon first request, limit and resetAfter should be actual values
		c.addBucket(bucket, int(reset), int(count))
	}

	routeKey := parseRoute(path)

	if !c.mappingExists(routeKey) {
		c.addMapping(routeKey, bucket)
	}
}

var snowRe = regexp.MustCompile(`\d{17,19}`)

func parseRoute(path string) string {
	splitPath := strings.Split(path, "/")
	includeNext := true
	routeKeyParts := []string{}
	for _, c := range splitPath[3:] {
		isSnowflake := snowRe.MatchString(c)
		if isSnowflake && includeNext {
			routeKeyParts = append(routeKeyParts, c)
			includeNext = false
		} else if !isSnowflake {
			routeKeyParts = append(routeKeyParts, c)
			if c == "channels" || c == "guilds" || c == "webhooks" {
				includeNext = true
			}
		}
	}
	return strings.Join(routeKeyParts, ":")
}