diff --git a/.travis.yml b/.travis.yml index a90dc99..ee0952f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,5 +11,6 @@ before_install: - go get golang.org/x/tools/cmd/cover script: + - go vet ./... - go test -v -covermode=count -coverprofile=coverage.out diff --git a/README.md b/README.md index d1da701..a39158e 100644 --- a/README.md +++ b/README.md @@ -51,11 +51,11 @@ package main import "github.com/dinever/golf" -func mainHandler(ctx *Golf.Context) { +func mainHandler(ctx *golf.Context) { ctx.Write("Hello World!") } -func pageHandler(ctx *Golf.Context) { +func pageHandler(ctx *golf.Context) { page, err := ctx.Param("page") if err != nil { ctx.Abort(500) @@ -65,7 +65,7 @@ func pageHandler(ctx *Golf.Context) { } func main() { - app := Golf.New() + app := golf.New() app.Get("/", mainHandler) app.Get("/p/:page/", pageHandler) app.Run(":9000") diff --git a/app.go b/app.go index ddb6f8f..e5251df 100644 --- a/app.go +++ b/app.go @@ -1,10 +1,11 @@ -package Golf +package golf import ( "net/http" "os" "path" "strings" + "sync" ) // Application is an abstraction of a Golf application, can be used for @@ -25,16 +26,20 @@ type Application struct { SessionManager SessionManager // NotFoundHandler handles requests when no route is matched. - NotFoundHandler Handler + NotFoundHandler HandlerFunc // MiddlewareChain is the default middlewares that Golf uses. MiddlewareChain *Chain - errorHandler map[int]ErrorHandlerType + pool sync.Pool + + errorHandler map[int]ErrorHandlerFunc // The default error handler, if the corresponding error code is not specified // in the `errorHandler` map, this handler will be called. - DefaultErrorHandler ErrorHandlerType + DefaultErrorHandler ErrorHandlerFunc + + handlerChain HandlerFunc } // New is used for creating a new Golf Application instance. @@ -43,11 +48,14 @@ func New() *Application { app.router = newRouter() app.staticRouter = make(map[string][]string) app.View = NewView() - app.Config = NewConfig(app) + app.Config = NewConfig() // debug, _ := app.Config.GetBool("debug", false) - app.errorHandler = make(map[int]ErrorHandlerType) - app.MiddlewareChain = NewChain(defaultMiddlewares...) + app.errorHandler = make(map[int]ErrorHandlerFunc) + app.MiddlewareChain = NewChain() app.DefaultErrorHandler = defaultErrorHandler + app.pool.New = func() interface{} { + return new(Context) + } return app } @@ -67,12 +75,12 @@ func (app *Application) handler(ctx *Context) { } } - params, handler := app.router.match(ctx.Request.URL.Path, ctx.Request.Method) - if handler != nil { + handler, params, err := app.router.FindRoute(ctx.Request.Method, ctx.Request.URL.Path) + if err != nil { + app.handleError(ctx, 404) + } else { ctx.Params = params handler(ctx) - } else { - app.handleError(ctx, 404) } ctx.Send() } @@ -84,8 +92,16 @@ func staticHandler(ctx *Context, filePath string) { // Basic entrance of an `http.ResponseWriter` and an `http.Request`. func (app *Application) ServeHTTP(res http.ResponseWriter, req *http.Request) { - ctx := NewContext(req, res, app) - app.MiddlewareChain.Final(app.handler)(ctx) + if app.handlerChain == nil { + app.handlerChain = app.MiddlewareChain.Final(app.handler) + } + ctx := app.pool.Get().(*Context) + ctx.reset() + ctx.Request = req + ctx.Response = res + ctx.App = app + app.handlerChain(ctx) + app.pool.Put(ctx) } // Run the Golf Application. @@ -111,27 +127,42 @@ func (app *Application) Static(url string, path string) { } // Get method is used for registering a Get method route -func (app *Application) Get(pattern string, handler Handler) { - app.router.get(pattern, handler) +func (app *Application) Get(pattern string, handler HandlerFunc) { + app.router.AddRoute("GET", pattern, handler) } // Post method is used for registering a Post method route -func (app *Application) Post(pattern string, handler Handler) { - app.router.post(pattern, handler) +func (app *Application) Post(pattern string, handler HandlerFunc) { + app.router.AddRoute("POST", pattern, handler) } // Put method is used for registering a Put method route -func (app *Application) Put(pattern string, handler Handler) { - app.router.put(pattern, handler) +func (app *Application) Put(pattern string, handler HandlerFunc) { + app.router.AddRoute("PUT", pattern, handler) } // Delete method is used for registering a Delete method route -func (app *Application) Delete(pattern string, handler Handler) { - app.router.delete(pattern, handler) +func (app *Application) Delete(pattern string, handler HandlerFunc) { + app.router.AddRoute("DELETE", pattern, handler) +} + +// Patch method is used for registering a Patch method route +func (app *Application) Patch(pattern string, handler HandlerFunc) { + app.router.AddRoute("PATCH", pattern, handler) +} + +// Options method is used for registering a Options method route +func (app *Application) Options(pattern string, handler HandlerFunc) { + app.router.AddRoute("OPTIONS", pattern, handler) +} + +// Head method is used for registering a Head method route +func (app *Application) Head(pattern string, handler HandlerFunc) { + app.router.AddRoute("HEAD", pattern, handler) } // Error method is used for registering an handler for a specified HTTP error code. -func (app *Application) Error(statusCode int, handler ErrorHandlerType) { +func (app *Application) Error(statusCode int, handler ErrorHandlerFunc) { app.errorHandler[statusCode] = handler } diff --git a/config.go b/config.go index 80004e8..0359353 100644 --- a/config.go +++ b/config.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "encoding/json" @@ -38,61 +38,61 @@ type Config struct { } // NewConfig creates a new configuration instance. -func NewConfig(app *Application) *Config { +func NewConfig() *Config { mapping := make(map[string]interface{}) return &Config{mapping} } // GetString fetches the string value by indicating the key. // It returns a ValueTypeError if the value is not a sring. -func (config *Config) GetString(key string, defaultValue interface{}) (string, error) { +func (config *Config) GetString(key string, defaultValue string) (string, error) { value, err := config.Get(key, defaultValue) if err != nil { - return "", err + return defaultValue, err } if result, ok := value.(string); ok { return result, nil } - return "", &ValueTypeError{key: key, value: value, message: "Value is not a string."} + return defaultValue, &ValueTypeError{key: key, value: value, message: "Value is not a string."} } // GetInt fetches the int value by indicating the key. // It returns a ValueTypeError if the value is not a sring. -func (config *Config) GetInt(key string, defaultValue interface{}) (int, error) { +func (config *Config) GetInt(key string, defaultValue int) (int, error) { value, err := config.Get(key, defaultValue) if err != nil { - return 0, err + return defaultValue, err } if result, ok := value.(int); ok { return result, nil } - return 0, &ValueTypeError{key: key, value: value, message: "Value is not an integer."} + return defaultValue, &ValueTypeError{key: key, value: value, message: "Value is not an integer."} } // GetBool fetches the bool value by indicating the key. // It returns a ValueTypeError if the value is not a sring. -func (config *Config) GetBool(key string, defaultValue interface{}) (bool, error) { +func (config *Config) GetBool(key string, defaultValue bool) (bool, error) { value, err := config.Get(key, defaultValue) if err != nil { - return false, err + return defaultValue, err } if result, ok := value.(bool); ok { return result, nil } - return false, &ValueTypeError{key: key, value: value, message: "Value is not an bool."} + return defaultValue, &ValueTypeError{key: key, value: value, message: "Value is not an bool."} } // GetFloat fetches the float value by indicating the key. // It returns a ValueTypeError if the value is not a sring. -func (config *Config) GetFloat(key string, defaultValue interface{}) (float64, error) { +func (config *Config) GetFloat(key string, defaultValue float64) (float64, error) { value, err := config.Get(key, defaultValue) if err != nil { - return 0, err + return defaultValue, err } if result, ok := value.(float64); ok { return result, nil } - return 0, &ValueTypeError{key: key, value: value, message: "Value is not a float."} + return defaultValue, &ValueTypeError{key: key, value: value, message: "Value is not a float."} } // Set is used to set the value by indicating the key. @@ -151,7 +151,7 @@ func (config *Config) Get(key string, defaultValue interface{}) (interface{}, er if value, exists := mapping[item]; exists { tmp = value } else if defaultValue != nil { - return defaultValue, nil + return nil, &KeyError{key: path.Join(append(keys[:i], item)...)} } else { return nil, &KeyError{key: path.Join(append(keys[:i], item)...)} } diff --git a/config_test.go b/config_test.go index eb0302e..8bfe33e 100644 --- a/config_test.go +++ b/config_test.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "bytes" @@ -22,8 +22,7 @@ func TestConfig(t *testing.T) { defaultValue := "None" for _, c := range cases { - app := New() - config := NewConfig(app) + config := NewConfig() config.Set(c.key, c.value) value, err := config.Get(c.key, defaultValue) @@ -46,8 +45,7 @@ func TestConfigWithMultipleEntires(t *testing.T) { {"foo2", "bar4"}, } - app := New() - config := NewConfig(app) + config := NewConfig() for _, c := range settings { config.Set(c.key, c.value) @@ -80,3 +78,91 @@ func TestFromJSON(t *testing.T) { t.Errorf("expected value to be abc but it was %v", value) } } + +func TestGetStringException(t *testing.T) { + defaultValue := "None" + + config := NewConfig() + config.Set("foo", 123) + val, err := config.GetString("foo", defaultValue) + if err == nil { + t.Errorf("Should have raised an type error when getting a non-string value by GetString.") + } + if val != defaultValue { + t.Errorf("Should have used the default value when raising an error.") + } + + val, err = config.GetString("bar", defaultValue) + if err == nil { + t.Errorf("Should have raised an type error when getting a non-existed value by GetString.") + } + if val != defaultValue { + t.Errorf("Should have used the default value when raising an error.") + } +} + +func TestGetIntegerException(t *testing.T) { + defaultValue := 123 + + config := NewConfig() + config.Set("foo", "bar") + val, err := config.GetInt("foo", defaultValue) + if err == nil { + t.Errorf("Should have raised an type error when getting a non-string value by GetString.") + } + if val != defaultValue { + t.Errorf("Should have used the default value when raising an error.") + } + + val, err = config.GetInt("bar", defaultValue) + if err == nil { + t.Errorf("Should have raised an type error when getting a non-existed value by GetString.") + } + if val != defaultValue { + t.Errorf("Should have used the default value when raising an error.") + } +} + +func TestGetBoolException(t *testing.T) { + defaultValue := false + + config := NewConfig() + config.Set("foo", "bar") + val, err := config.GetBool("foo", defaultValue) + if err == nil { + t.Errorf("Should have raised an type error when getting a non-string value by GetString.") + } + if val != defaultValue { + t.Errorf("Should have used the default value when raising an error.") + } + + val, err = config.GetBool("bar", defaultValue) + if err == nil { + t.Errorf("Should have raised an type error when getting a non-existed value by GetString.") + } + if val != defaultValue { + t.Errorf("Should have used the default value when raising an error.") + } +} + +func TestGetFloat64Exception(t *testing.T) { + defaultValue := 0.5 + + config := NewConfig() + config.Set("foo", "bar") + val, err := config.GetFloat("foo", defaultValue) + if err == nil { + t.Errorf("Should have raised an type error when getting a non-string value by GetString.") + } + if val != defaultValue { + t.Errorf("Should have used the default value when raising an error.") + } + + val, err = config.GetFloat("bar", defaultValue) + if err == nil { + t.Errorf("Should have raised an type error when getting a non-existed value by GetString.") + } + if val != defaultValue { + t.Errorf("Should have used the default value when raising an error.") + } +} diff --git a/context.go b/context.go index b2d2d0d..1ab8602 100644 --- a/context.go +++ b/context.go @@ -1,12 +1,12 @@ -package Golf +package golf import ( "encoding/hex" "encoding/json" "errors" + "fmt" "net/http" "time" - "fmt" ) // Context is a wrapper of http.Request and http.ResponseWriter. @@ -18,7 +18,7 @@ type Context struct { Response http.ResponseWriter // URL Parameter - Params map[string]string + Params Parameter // HTTP status code StatusCode int @@ -60,6 +60,11 @@ func NewContext(req *http.Request, res http.ResponseWriter, app *Application) *C return ctx } +func (ctx *Context) reset() { + ctx.StatusCode = 200 + ctx.IsSent = false +} + func (ctx *Context) generateSession() Session { s, err := ctx.App.SessionManager.NewSession() if err != nil { @@ -98,10 +103,8 @@ func (ctx *Context) Query(key string, index ...int) (string, error) { // Param method retrieves the parameters from url // If the url is /:id/, then id can be retrieved by calling `ctx.Param(id)` func (ctx *Context) Param(key string) (string, error) { - if val, ok := ctx.Params[key]; ok { - return val, nil - } - return "", errors.New("Parameter not found.") + val, err := ctx.Params.ByName(key) + return val, err } // Redirect method sets the response as a 301 redirection. diff --git a/context_test.go b/context_test.go index c17aeb4..27a1e3c 100644 --- a/context_test.go +++ b/context_test.go @@ -1,7 +1,8 @@ -package Golf +package golf import ( "bufio" + "encoding/hex" "fmt" "io" "net/http" @@ -9,7 +10,6 @@ import ( "reflect" "strings" "testing" - "encoding/hex" ) func makeTestHTTPRequest(body io.Reader, method, url string) *http.Request { diff --git a/error.go b/error.go index 93e61f9..6e40a3f 100644 --- a/error.go +++ b/error.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "bytes" diff --git a/middleware.go b/middleware.go index 71e0e17..abd9ac1 100644 --- a/middleware.go +++ b/middleware.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "log" @@ -6,7 +6,7 @@ import ( "time" ) -type middlewareHandler func(next Handler) Handler +type middlewareHandler func(next HandlerFunc) HandlerFunc var defaultMiddlewares = []middlewareHandler{LoggingMiddleware, RecoverMiddleware, XSRFProtectionMiddleware, SessionMiddleware} @@ -24,12 +24,11 @@ func NewChain(handlerArray ...middlewareHandler) *Chain { // Final indicates a final Handler, chain the multiple middlewares together with the // handler, and return them together as a handler. -func (c Chain) Final(fn Handler) Handler { - final := fn +func (c Chain) Final(fn HandlerFunc) HandlerFunc { for i := len(c.middlewareHandlers) - 1; i >= 0; i-- { - final = c.middlewareHandlers[i](final) + fn = c.middlewareHandlers[i](fn) } - return final + return fn } // Append a middleware to the middleware chain @@ -38,7 +37,7 @@ func (c *Chain) Append(fn middlewareHandler) { } // LoggingMiddleware is the built-in middleware for logging. -func LoggingMiddleware(next Handler) Handler { +func LoggingMiddleware(next HandlerFunc) HandlerFunc { fn := func(ctx *Context) { t1 := time.Now() next(ctx) @@ -49,7 +48,7 @@ func LoggingMiddleware(next Handler) Handler { } // XSRFProtectionMiddleware is the built-in middleware for XSRF protection. -func XSRFProtectionMiddleware(next Handler) Handler { +func XSRFProtectionMiddleware(next HandlerFunc) HandlerFunc { fn := func(ctx *Context) { xsrfEnabled, _ := ctx.App.Config.GetBool("xsrf_cookies", false) if xsrfEnabled && (ctx.Request.Method == "POST" || ctx.Request.Method == "PUT" || ctx.Request.Method == "DELETE") { @@ -64,7 +63,7 @@ func XSRFProtectionMiddleware(next Handler) Handler { } // RecoverMiddleware is the built-in middleware for recovering from errors. -func RecoverMiddleware(next Handler) Handler { +func RecoverMiddleware(next HandlerFunc) HandlerFunc { fn := func(ctx *Context) { defer func() { if err := recover(); err != nil { @@ -86,7 +85,8 @@ func RecoverMiddleware(next Handler) Handler { return fn } -func SessionMiddleware(next Handler) Handler { +// SessionMiddleware handles session of the request +func SessionMiddleware(next HandlerFunc) HandlerFunc { fn := func(ctx *Context) { if ctx.App.SessionManager != nil { ctx.retrieveSession() diff --git a/router.go b/router.go index c39c289..f35f744 100644 --- a/router.go +++ b/router.go @@ -1,103 +1,147 @@ -package Golf +package golf import ( - "regexp" - "strings" + "fmt" ) -const ( - routerMethodGet = "GET" - routerMethodPost = "POST" - routerMethodPut = "PUT" - routerMethodDelete = "DELETE" -) +// HandlerFunc is the type of the handler function that Golf accepts. +type HandlerFunc func(ctx *Context) + +// ErrorHandlerFunc is the type of the function that handles error in Golf. +type ErrorHandlerFunc func(ctx *Context, data ...map[string]interface{}) type router struct { - routeSlice []*route + trees map[string]*node } -// Handler is the type of the handler function that Golf accepts. -type Handler func(ctx *Context) +func newRouter() *router { + return &router{trees: make(map[string]*node)} +} -// ErrorHandlerType is the type of the function that handles error in Golf. -type ErrorHandlerType func(ctx *Context, data ...map[string]interface{}) +func splitURLpath(path string) (parts []string, names map[string]int) { -// ErrorHandler is the type of the error handler function that Golf accepts. -type ErrorHandler func(ctx *Context, e error) + var ( + nameidx = -1 + partidx int + paramCounter int + ) -type route struct { - method string - pattern string - regex *regexp.Regexp - params []string - handler Handler -} + for i := 0; i < len(path); i++ { -func newRouter() *router { - r := new(router) - r.routeSlice = make([]*route, 0) - return r -} + if names == nil { + names = make(map[string]int) + } + // recording name + if nameidx != -1 { + //found / + if path[i] == '/' { + names[path[nameidx:i]] = paramCounter + paramCounter++ + + nameidx = -1 // switch to normal recording + partidx = i + } + } else { + if path[i] == ':' || path[i] == '*' { + if path[i-1] != '/' { + panic(fmt.Errorf("Invalid parameter : or * should always be after / - %q", path)) + } + nameidx = i + 1 + if partidx != i { + parts = append(parts, path[partidx:i]) + } + parts = append(parts, path[i:nameidx]) + } + } + } -func newRoute(method string, pattern string, handler Handler) *route { - route := new(route) - route.pattern = pattern - route.params = make([]string, 0) - route.regex, route.params = route.parseURL(pattern) - route.method = method - route.handler = handler - return route + if nameidx != -1 { + names[path[nameidx:]] = paramCounter + paramCounter++ + } else if partidx < len(path) { + parts = append(parts, path[partidx:]) + } + return } -func (router *router) get(pattern string, handler Handler) { - route := newRoute(routerMethodGet, pattern, handler) - router.registerRoute(route, handler) +func (router *router) FindRoute(method string, path string) (HandlerFunc, Parameter, error) { + node := router.trees[method] + if node == nil { + return nil, Parameter{}, fmt.Errorf("Can not find route") + } + matchedNode, err := node.findRoute(path) + if err != nil { + return nil, Parameter{}, err + } + return matchedNode.handler, Parameter{node: matchedNode, path: path}, err } -func (router *router) post(pattern string, handler Handler) { - route := newRoute(routerMethodPost, pattern, handler) - router.registerRoute(route, handler) +func (router *router) AddRoute(method string, path string, handler HandlerFunc) { + var ( + rootNode *node + ok bool + ) + parts, names := splitURLpath(path) + if rootNode, ok = router.trees[method]; !ok { + rootNode = &node{} + router.trees[method] = rootNode + } + rootNode.addRoute(parts, names, handler) + rootNode.optimizeRoutes() } -func (router *router) put(pattern string, handler Handler) { - route := newRoute(routerMethodPut, pattern, handler) - router.registerRoute(route, handler) +//Parameter holds the parameters matched in the route +type Parameter struct { + *node // matched node + path string // url path given + cached map[string]string } -func (router *router) delete(pattern string, handler Handler) { - route := newRoute(routerMethodDelete, pattern, handler) - router.registerRoute(route, handler) +//Len returns number arguments matched in the provided URL +func (p *Parameter) Len() int { + return len(p.names) } -func (router *router) registerRoute(route *route, handler Handler) { - router.routeSlice = append(router.routeSlice, route) +//ByName returns the url parameter by name +func (p *Parameter) ByName(name string) (string, error) { + if i, has := p.names[name]; has { + return p.findParam(i) + } + return "", fmt.Errorf("Parameter not found") } -func (router *router) match(url string, method string) (params map[string]string, handler Handler) { - params = make(map[string]string) - for _, route := range router.routeSlice { - if method == route.method && route.regex.MatchString(url) { - subMatch := route.regex.FindStringSubmatch(url) - for i, param := range route.params { - params[param] = subMatch[i+1] - } - handler = route.handler - return params, handler +func lastIndexByte(s string, c byte) int { + for i := len(s) - 1; i >= 0; i-- { + if s[i] == c { + return i } } - return nil, nil + return -1 } -// Parse the URL to a regexp and a map of parameters -func (route *route) parseURL(pattern string) (regex *regexp.Regexp, params []string) { - params = make([]string, 0) - segments := strings.Split(pattern, "/") - for i, segment := range segments { - if strings.HasPrefix(segment, ":") { - segments[i] = `([\w-%]+)` - params = append(params, strings.TrimPrefix(segment, ":")) +//findParam walks up the matched node looking for parameters returns the last parameter +func (p *Parameter) findParam(idx int) (string, error) { + index := len(p.names) - 1 + urlPath := p.path + pathLen := len(p.path) + node := p.node + + for node != nil { + if node.text[0] == ':' { + ctn := lastIndexByte(urlPath, '/') + if ctn == -1 { + return "", fmt.Errorf("Parameter not found") + } + pathLen = ctn + 1 + if index == idx { + return urlPath[pathLen:], nil + } + index-- + } else { + pathLen -= len(node.text) } + urlPath = urlPath[0:pathLen] + node = node.parent } - regex, _ = regexp.Compile("^" + strings.Join(segments, "/") + "$") - return regex, params + return "", fmt.Errorf("Parameter not found") } diff --git a/router_test.go b/router_test.go index 486b91c..ee85174 100644 --- a/router_test.go +++ b/router_test.go @@ -1,82 +1,135 @@ -package Golf +package golf import ( - "reflect" "testing" ) -func handler(ctx *Context) {} - -func TestParsePatternWithOneParam(t *testing.T) { - cases := []struct { - method, in, regex, param string - }{ - {routerMethodGet, "/:id/", `^/([\w-%]+)/$`, "id"}, - {routerMethodPost, "/:id/", `^/([\w-%]+)/$`, "id"}, - {routerMethodPut, "/:id/", `^/([\w-%]+)/$`, "id"}, - {routerMethodDelete, "/:id/", `^/([\w-%]+)/$`, "id"}, +func assertStringEqual(t *testing.T, expected, got string) { + if expected != got { + t.Errorf("Expected %v, got %v", expected, got) } +} - for _, c := range cases { - route := newRoute(c.method, c.in, handler) - if route.regex.String() != c.regex { - t.Errorf("regex of %q == %q, want %q", c.in, route.regex.String(), c.regex) - } - if len(route.params) != 1 { - t.Errorf("%q is supposed to have 1 parameter", c.in) - } - if route.params[0] != "id" { - t.Errorf("params[0] == %q, want %q", c.in, c.param) +func assertSliceEqual(t *testing.T, expected, got []string) { + if len(expected) != len(got) { + t.Errorf("Slice length not equal, expected: %v, got %v", expected, got) + } + for i := 0; i < len(expected); i++ { + if expected[i] != got[i] { + t.Errorf("Slice not equal, expected: %v, got %v", expected, got) } } } -func TestParsePatternWithThreeParam(t *testing.T) { - cases := []struct { - in, regex string - params []string - }{ - { - "/:year/:month/:day/", - `^/([\w-%]+)/([\w-%]+)/([\w-%]+)/$`, - []string{"year", "month", "day"}, - }, +type route struct { + method string + path string + testPath string + params map[string]string +} + +var githubAPI = []route{ + // OAuth Authorizations + {"GET", "/authorizations", "/authorizations", map[string]string{}}, + {"GET", "/auth", "/auth", map[string]string{}}, + {"GET", "/authorizations/:id", "/authorizations/12345", map[string]string{"id": "12345"}}, + {"POST", "/authorizations", "/authorizations", map[string]string{}}, + {"DELETE", "/authorizations/:id", "/authorizations/12345", map[string]string{"id": "12345"}}, + {"GET", "/applications/:client_id/tokens/:access_token", "/applications/12345/tokens/67890", map[string]string{"client_id": "12345", "access_token": "67890"}}, + {"DELETE", "/applications/:client_id/tokens", "/applications/12345/tokens", map[string]string{"client_id": "12345"}}, + {"DELETE", "/applications/:client_id/tokens/:access_token", "/applications/12345/tokens/67890", map[string]string{"client_id": "12345", "access_token": "67890"}}, + + // Activity + {"GET", "/events", "/events", nil}, + {"GET", "/repos/:owner/:repo/events", "/repos/dinever/golf/events", map[string]string{"owner": "dinever", "repo": "golf"}}, + {"GET", "/networks/:owner/:repo/events", "/networks/dinever/golf/events", map[string]string{"owner": "dinever", "repo": "golf"}}, + {"GET", "/orgs/:org/events", "/orgs/golf/events", map[string]string{"org": "golf"}}, + {"GET", "/users/:user/received_events", "/users/dinever/received_events", nil}, + {"GET", "/users/:user/received_events/public", "/users/dinever/received_events/public", nil}, +} + +func handler(ctx *Context) { +} + +func TestRouter(t *testing.T) { + router := newRouter() + for _, route := range githubAPI { + router.AddRoute(route.method, route.path, handler) } - for _, c := range cases { - route := newRoute(routerMethodGet, c.in, handler) - if route.regex.String() != c.regex { - t.Errorf("regex == %q, want %q", c.in, route.regex.String()) + for _, route := range githubAPI { + _, param, err := router.FindRoute(route.method, route.testPath) + if err != nil { + t.Errorf("Can not find route: %v", route.testPath) } - if !reflect.DeepEqual(route.params, c.params) { - t.Errorf("parameters not match: %v != %v", route.params, c.params) + + for key, expected := range route.params { + val, err := param.ByName(key) + if err != nil { + t.Errorf("Can not retrieve parameter from route %v: %v", route.testPath, key) + } else { + assertStringEqual(t, expected, val) + } + val, err = param.ByName(key) + if err != nil { + t.Errorf("Can not retrieve parameter from route %v: %v", route.testPath, key) + } else { + assertStringEqual(t, expected, val) + } } } } -func TestRouterMatch(t *testing.T) { +func TestSplitURLPath(t *testing.T) { + + var table = map[string][2][]string{ + "/users/:name": {{"/users/", ":"}, {"name"}}, + "/users/:name/put": {{"/users/", ":", "/put"}, {"name"}}, + "/users/:name/put/:section": {{"/users/", ":", "/put/", ":"}, {"name", "section"}}, + "/customers/:name/put/:section": {{"/customers/", ":", "/put/", ":"}, {"name", "section"}}, + "/customers/groups/:name/put/:section": {{"/customers/groups/", ":", "/put/", ":"}, {"name", "section"}}, + } + + for path, result := range table { + parts, _ := splitURLpath(path) + assertSliceEqual(t, parts, result[0]) + } +} + +func TestIncorrectPath(t *testing.T) { + path := "/users/foo:name/" + defer func() { + if err := recover(); err != nil { + } + }() router := newRouter() - cases := []struct { - pattern string - url string - params map[string]string + router.AddRoute("GET", path, handler) + t.Errorf("Incorrect path should raise an error.") +} + +func TestPathNotFound(t *testing.T) { + path := []struct { + method, path, incomingMethod, incomingPath string }{ - { - "/:year/:month/:day/", - "/2015/11/15/", - map[string]string{"year": "2015", "month": "11", "day": "15"}, - }, - { - "/user/:id/", - "/user/foobar/", - map[string]string{"id": "foobar"}, - }, + {"GET", "/users/name/", "GET", "/users/name/dinever/"}, + {"GET", "/dinever/repo/", "POST", "/dinever/repo/"}, } - for _, c := range cases { - router.get(c.pattern, handler) - params, _ := router.match(c.url, routerMethodGet) - if !reflect.DeepEqual(params, c.params) { - t.Errorf("parameters not match: %v != %v", params, c.params) + defer func() { + if err := recover(); err != nil { + } + }() + router := newRouter() + for _, path := range path { + router.AddRoute(path.method, path.path, handler) + h, p, err := router.FindRoute(path.incomingMethod, path.incomingPath) + if h != nil { + t.Errorf("Should return nil handler when path not found.") + } + if p.Len() != 0 { + t.Errorf("Should return nil parameter when path not found.") + } + if err == nil { + t.Errorf("Should rasie an error when path not found.") } } } diff --git a/session.go b/session.go index a623c1a..ff8b1cd 100644 --- a/session.go +++ b/session.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "crypto/rand" @@ -79,6 +79,7 @@ func (mgr *MemorySessionManager) GarbageCollection() { time.AfterFunc(time.Duration(gcTimeInterval)*time.Second, mgr.GarbageCollection) } +// Count returns the number of the current session stored in the session manager. func (mgr *MemorySessionManager) Count() int { return len(mgr.sessions) } diff --git a/session_test.go b/session_test.go index 01e24d2..509a21e 100644 --- a/session_test.go +++ b/session_test.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "testing" @@ -73,3 +73,19 @@ func TestMemorySessionNotExpire(t *testing.T) { t.Errorf("Falsely recycled non-expired sessions.") } } + +func TestMemorySessionCount(t *testing.T) { + mgr := NewMemorySessionManager() + for i := 0; i < 100; i++ { + sid, _ := mgr.sessionID() + s := MemorySession{sid: sid, data: make(map[string]interface{}), createdAt: time.Now().AddDate(0, 0, -1)} + mgr.sessions[sid] = &s + } + if mgr.Count() != 100 { + t.Errorf("Could not correctly get session count: %v != %v", mgr.Count(), 100) + } + mgr.GarbageCollection() + if mgr.Count() != 0 { + t.Errorf("Could not correctly get session count after GC: %v != %v", mgr.Count(), 0) + } +} diff --git a/template.go b/template.go index 39b0da0..d3e1ca0 100644 --- a/template.go +++ b/template.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "fmt" diff --git a/template_test.go b/template_test.go index f3b94bc..e4df404 100644 --- a/template_test.go +++ b/template_test.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "bytes" diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..8cead84 --- /dev/null +++ b/tree.go @@ -0,0 +1,179 @@ +package golf + +import ( + "fmt" + "sort" + "strings" +) + +type node struct { + text string + names map[string]int + handler HandlerFunc + + parent *node + colon *node + + children nodes + start byte + max byte + indices []uint8 +} + +type nodes []*node + +func (s nodes) Len() int { + return len(s) +} + +func (s nodes) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s nodes) Less(i, j int) bool { + return s[i].text[0] < s[j].text[0] +} + +func (n *node) matchNode(path string) (*node, int8, int) { + + if path == ":" { + if n.colon == nil { + n.colon = &node{text: ":"} + } + return n.colon, 0, 0 + } + + for i, child := range n.children { + if child.text[0] == path[0] { + + maxLength := len(child.text) + pathLength := len(path) + var pathCompare int8 + + if pathLength > maxLength { + pathCompare = 1 + } else if pathLength < maxLength { + maxLength = pathLength + pathCompare = -1 + } + + for j := 0; j < maxLength; j++ { + if path[j] != child.text[j] { + ccNode := &node{text: path[0:j], children: nodes{child, &node{text: path[j:]}}} + child.text = child.text[j:] + n.children[i] = ccNode + return ccNode.children[1], 0, i + } + } + + return child, pathCompare, i + } + } + + return nil, 0, 0 +} + +func (n *node) addRoute(parts []string, names map[string]int, handler HandlerFunc) { + + var ( + tmpNode *node + currentNode *node + loop = true + ) + + currentNode, result, i := n.matchNode(parts[0]) + + for loop == true { + if currentNode == nil { + currentNode = &node{text: parts[0]} + n.children = append(n.children, currentNode) + } else if result == 1 { + // + parts[0] = parts[0][len(currentNode.text):] + tmpNode, result, i = currentNode.matchNode(parts[0]) + n = currentNode + currentNode = tmpNode + continue + } else if result == -1 { + tmpNode := &node{text: parts[0]} + currentNode.text = currentNode.text[len(tmpNode.text):] + tmpNode.children = nodes{currentNode} + n.children[i] = tmpNode + currentNode = tmpNode + } + break + } + + if len(parts) == 1 { + currentNode.handler = handler + currentNode.names = names + return + } + + currentNode.addRoute(parts[1:], names, handler) +} + +func (n *node) findRoute(urlPath string) (*node, error) { + + urlByte := urlPath[0] + pathLen := len(urlPath) + + if urlByte >= n.start && urlByte <= n.max { + if i := n.indices[urlByte-n.start]; i != 0 { + matched := n.children[i-1] + nodeLen := len(matched.text) + if nodeLen < pathLen { + if matched.text == urlPath[:nodeLen] { + if matched, _ := matched.findRoute(urlPath[nodeLen:]); matched != nil { + return matched, nil + } + } + } else if matched.text == urlPath { + return matched, nil + } + } + } + + if n.colon != nil && pathLen != 0 { + i := strings.IndexByte(urlPath, '/') + if i > 0 { + if cNode, err := n.colon.findRoute(urlPath[i:]); cNode != nil { + return cNode, err + } + } else if n.colon.handler != nil { + return n.colon, nil + } + } + + return nil, fmt.Errorf("Can not find route") +} + +func (n *node) optimizeRoutes() { + + if len(n.children) > 0 { + sort.Sort(n.children) + for i := 0; i < len(n.indices); i++ { + n.indices[i] = 0 + } + + n.start = n.children[0].text[0] + n.max = n.children[len(n.children)-1].text[0] + + for i := 0; i < len(n.children); i++ { + cNode := n.children[i] + cNode.parent = n + + cByte := int(cNode.text[0] - n.start) + if cByte >= len(n.indices) { + n.indices = append(n.indices, make([]uint8, cByte+1-len(n.indices))...) + } + n.indices[cByte] = uint8(i + 1) + cNode.optimizeRoutes() + } + } + + if n.colon != nil { + n.colon.parent = n + n.colon.optimizeRoutes() + } +} diff --git a/view.go b/view.go index 60b14e4..8a137bc 100644 --- a/view.go +++ b/view.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "bytes" diff --git a/view_test.go b/view_test.go index 8cba5ab..904f931 100644 --- a/view_test.go +++ b/view_test.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "testing" diff --git a/xsrf.go b/xsrf.go index dc5adc3..03d566d 100644 --- a/xsrf.go +++ b/xsrf.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "encoding/hex" diff --git a/xsrf_test.go b/xsrf_test.go index 887baed..862e758 100644 --- a/xsrf_test.go +++ b/xsrf_test.go @@ -1,4 +1,4 @@ -package Golf +package golf import ( "encoding/hex" @@ -19,7 +19,8 @@ func TestDecodeXSRF(t *testing.T) { maskedToken := newXSRFToken() maskedTokenBytes, _ := hex.DecodeString(maskedToken) mask, token, _ := decodeXSRFToken(maskedToken) - if !compareToken(maskedTokenBytes, append(mask, websocketMask(mask, token)...)) { - t.Errorf("Could not genearte correct XSRF token. %v != %v") + result := append(mask, websocketMask(mask, token)...) + if !compareToken(maskedTokenBytes, result) { + t.Errorf("Could not genearte correct XSRF token. %v != %v", maskedTokenBytes, result) } }